1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
17
18 #include <algorithm>
19 #include <ostream>
20 #include <set>
21 #include <unordered_set>
22 #include <utility>
23
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/ascii.h"
30 #include "absl/strings/escaping.h"
31 #include "absl/strings/numbers.h"
32 #include "absl/strings/str_cat.h"
33 #include "absl/strings/str_join.h"
34 #include "absl/types/span.h"
35 #include "tensorflow/compiler/xla/layout_util.h"
36 #include "tensorflow/compiler/xla/literal.h"
37 #include "tensorflow/compiler/xla/protobuf_util.h"
38 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
39 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
40 #include "tensorflow/compiler/xla/service/hlo_computation.h"
41 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
42 #include "tensorflow/compiler/xla/service/hlo_module.h"
43 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
44 #include "tensorflow/compiler/xla/service/name_uniquer.h"
45 #include "tensorflow/compiler/xla/shape_util.h"
46 #include "tensorflow/compiler/xla/status_macros.h"
47 #include "tensorflow/compiler/xla/types.h"
48 #include "tensorflow/compiler/xla/util.h"
49 #include "tensorflow/core/lib/core/errors.h"
50 #include "tensorflow/core/lib/gtl/map_util.h"
51 #include "tensorflow/core/platform/human_readable_json.h"
52 #include "tensorflow/core/platform/logging.h"
53
54 namespace xla {
55
56 using absl::CEscape;
57 using absl::StrAppend;
58 using absl::StrCat;
59 using absl::StrJoin;
60
61 /* static */
CreateFromProto(const HloInstructionProto & proto,const absl::flat_hash_map<int64,HloInstruction * > & instruction_map,const absl::flat_hash_map<int64,HloComputation * > & computation_map,bool prohibit_empty_literal)62 StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
63 const HloInstructionProto& proto,
64 const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
65 const absl::flat_hash_map<int64, HloComputation*>& computation_map,
66 bool prohibit_empty_literal) {
67 TF_RET_CHECK(!proto.opcode().empty());
68 HloOpcode opcode;
69 auto opcode_or = StringToHloOpcode(proto.opcode());
70 absl::optional<ComparisonDirection> comparison_direction;
71 if (opcode_or.ok()) {
72 opcode = opcode_or.ConsumeValueOrDie();
73 } else {
74 // Unknown opcode. Try auto-upgrading deprecated "less-than",
75 // "greater-than", etc opcodes, which are now rolled into the kCompare
76 // opcode.
77 if (proto.opcode() == "equal-to") {
78 comparison_direction = ComparisonDirection::kEq;
79 } else if (proto.opcode() == "not-equal-to") {
80 comparison_direction = ComparisonDirection::kNe;
81 } else if (proto.opcode() == "greater-than-or-equal-to") {
82 comparison_direction = ComparisonDirection::kGe;
83 } else if (proto.opcode() == "greater-than") {
84 comparison_direction = ComparisonDirection::kGt;
85 } else if (proto.opcode() == "less-than-or-equal-to") {
86 comparison_direction = ComparisonDirection::kLe;
87 } else if (proto.opcode() == "less-than") {
88 comparison_direction = ComparisonDirection::kLt;
89 }
90 if (comparison_direction) {
91 opcode = HloOpcode::kCompare;
92 } else {
93 return InvalidArgument("Unknown opcode: %s", proto.opcode());
94 }
95 }
96
97 TF_RET_CHECK(proto.has_shape());
98
99 std::unique_ptr<HloInstruction> instruction;
100 const auto operands = [&instruction_map, &proto](int index) {
101 return instruction_map.at(proto.operand_ids(index));
102 };
103 const auto all_operands = [&instruction_map, &proto]() {
104 std::vector<HloInstruction*> result(proto.operand_ids_size());
105 std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
106 result.begin(), [&instruction_map](int64 operand_id) {
107 return instruction_map.at(operand_id);
108 });
109 return result;
110 };
111 const auto computations = [&computation_map, &proto](int index) {
112 return computation_map.at(proto.called_computation_ids(index));
113 };
114 const auto all_computations = [&computation_map, &proto]() {
115 std::vector<HloComputation*> result(proto.called_computation_ids_size());
116 std::transform(proto.called_computation_ids().begin(),
117 proto.called_computation_ids().end(), result.begin(),
118 [&computation_map](int64 computation_id) {
119 return computation_map.at(computation_id);
120 });
121 return result;
122 };
123
124 TF_RET_CHECK(
125 absl::c_all_of(proto.operand_ids(),
126 [&](int64 id) { return instruction_map.contains(id); }))
127 << proto.name() << " instruction contains invalid operand id(s)";
128
129 TF_RET_CHECK(
130 absl::c_all_of(proto.called_computation_ids(),
131 [&](int64 id) { return computation_map.contains(id); }))
132 << proto.name() << " instruction references invalid computation id(s)";
133
134 Shape shape(proto.shape());
135 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
136
137 absl::optional<int> arity = HloOpcodeArity(opcode);
138 if (arity) {
139 TF_RET_CHECK(proto.operand_ids_size() == *arity)
140 << proto.opcode() << " instruction should have " << *arity
141 << " operands but sees " << proto.operand_ids_size();
142 }
143
144 switch (opcode) {
145 // Ops migrated to subclasses.
146 case HloOpcode::kBatchNormTraining:
147 instruction =
148 CreateBatchNormTraining(shape, operands(0), operands(1), operands(2),
149 proto.epsilon(), proto.feature_index());
150 break;
151 case HloOpcode::kBatchNormInference:
152 instruction = CreateBatchNormInference(
153 shape, operands(0), operands(1), operands(2), operands(3),
154 operands(4), proto.epsilon(), proto.feature_index());
155 break;
156 case HloOpcode::kBatchNormGrad:
157 instruction = CreateBatchNormGrad(shape, operands(0), operands(1),
158 operands(2), operands(3), operands(4),
159 proto.epsilon(), proto.feature_index());
160 break;
161 case HloOpcode::kFft: {
162 std::vector<int64> fft_length(proto.fft_length().begin(),
163 proto.fft_length().end());
164 instruction = CreateFft(shape, operands(0), proto.fft_type(),
165 absl::Span<const int64>(fft_length));
166 break;
167 }
168 case HloOpcode::kCompare: {
169 // Auto-upgraded from deprecated opcode skips the following.
170 if (!comparison_direction) {
171 TF_ASSIGN_OR_RETURN(
172 comparison_direction,
173 StringToComparisonDirection(proto.comparison_direction()));
174 }
175 instruction =
176 CreateCompare(shape, operands(0), operands(1), *comparison_direction);
177 break;
178 }
179 case HloOpcode::kTriangularSolve: {
180 instruction = CreateTriangularSolve(shape, operands(0), operands(1),
181 proto.triangular_solve_options());
182 break;
183 }
184 case HloOpcode::kCholesky: {
185 instruction =
186 CreateCholesky(shape, operands(0), proto.cholesky_options());
187 break;
188 }
189 case HloOpcode::kSend:
190 instruction = CreateSend(operands(0), operands(1), proto.channel_id(),
191 proto.is_host_transfer());
192 break;
193 case HloOpcode::kSendDone:
194 instruction = CreateSendDone(operands(0), proto.is_host_transfer());
195 break;
196 case HloOpcode::kRecv:
197 instruction = CreateRecv(shape.tuple_shapes(0), operands(0),
198 proto.channel_id(), proto.is_host_transfer());
199 break;
200 case HloOpcode::kRecvDone:
201 instruction = CreateRecvDone(operands(0), proto.is_host_transfer());
202 break;
203 case HloOpcode::kReverse:
204 instruction = CreateReverse(shape, operands(0),
205 std::vector<int64>(proto.dimensions().begin(),
206 proto.dimensions().end()));
207 break;
208 case HloOpcode::kConcatenate:
209 TF_RET_CHECK(proto.dimensions_size() == 1)
210 << "Concatenate instruction should have 1 dimension but sees "
211 << proto.dimensions_size();
212 instruction =
213 CreateConcatenate(shape, all_operands(), proto.dimensions(0));
214 break;
215 case HloOpcode::kConditional: {
216 TF_RET_CHECK(proto.called_computation_ids_size() > 0)
217 << "conditional should have at least 1 called computation";
218 if (operands(0)->shape().element_type() == PRED) {
219 TF_RET_CHECK(proto.called_computation_ids_size() == 2)
220 << "conditional should have exactly 2 called computations but got "
221 << proto.called_computation_ids_size();
222 }
223 TF_RET_CHECK(proto.operand_ids_size() ==
224 proto.called_computation_ids_size() + 1)
225 << "conditional should have one branch_index operand plus one "
226 "operand per called computation but got "
227 << proto.operand_ids_size() << " operands for "
228 << proto.called_computation_ids_size() << " branch computations";
229 auto cond_operands = all_operands();
230 instruction =
231 CreateConditional(shape, cond_operands[0], all_computations(),
232 absl::MakeSpan(cond_operands).subspan(1));
233 break;
234 }
235 case HloOpcode::kReduce:
236 TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
237 << "Reduce instruction should have an even number of operands but "
238 "sees "
239 << proto.operand_ids_size();
240 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
241 << "Reduce instruction should have 1 called computation but sees "
242 << proto.called_computation_ids_size();
243 {
244 const auto reduce_operands = all_operands();
245 auto inputs = absl::MakeSpan(reduce_operands)
246 .subspan(0, reduce_operands.size() / 2);
247 auto init_values =
248 absl::MakeSpan(reduce_operands)
249 .subspan(reduce_operands.size() / 2, reduce_operands.size());
250 instruction =
251 CreateReduce(shape, inputs, init_values,
252 std::vector<int64>(proto.dimensions().begin(),
253 proto.dimensions().end()),
254 computations(0));
255 }
256 break;
257 case HloOpcode::kSort: {
258 TF_RET_CHECK(proto.operand_ids_size() >= 1)
259 << "Sort instruction should have at least 1 operand but has "
260 << proto.operand_ids_size();
261 TF_RET_CHECK(proto.dimensions().size() == 1)
262 << "Sort instruction should have 1 dimension";
263 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
264 << "Sort instruction should one called computation but sees "
265 << proto.called_computation_ids_size();
266 auto sort_operands = all_operands();
267 instruction = CreateSort(shape, proto.dimensions(0), all_operands(),
268 computations(0), proto.is_stable());
269 break;
270 }
271 case HloOpcode::kTranspose:
272 instruction =
273 CreateTranspose(shape, operands(0),
274 std::vector<int64>(proto.dimensions().begin(),
275 proto.dimensions().end()));
276 break;
277 case HloOpcode::kBroadcast:
278 instruction =
279 CreateBroadcast(shape, operands(0),
280 std::vector<int64>(proto.dimensions().begin(),
281 proto.dimensions().end()));
282 break;
283 case HloOpcode::kMap:
284 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
285 << "Map instruction should have 1 called computation but sees "
286 << proto.called_computation_ids_size();
287 instruction = CreateMap(shape, all_operands(), computations(0));
288 break;
289 case HloOpcode::kSlice: {
290 std::vector<int64> slice_starts, slice_limits, slice_strides;
291 for (const HloInstructionProto::SliceDimensions& slice_dimensions :
292 proto.slice_dimensions()) {
293 slice_starts.push_back(slice_dimensions.start());
294 slice_limits.push_back(slice_dimensions.limit());
295 slice_strides.push_back(slice_dimensions.stride());
296 }
297 instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits,
298 slice_strides);
299 break;
300 }
301 case HloOpcode::kConstant: {
302 // TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
303 if (proto.has_literal()) {
304 TF_ASSIGN_OR_RETURN(
305 auto literal,
306 Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
307 instruction = CreateConstant(std::move(literal));
308 // Literal's shape may have no/different tiling info.
309 TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
310 instruction->shape(), shape));
311 *instruction->mutable_shape() = shape;
312 } else {
313 instruction = absl::make_unique<HloConstantInstruction>(shape);
314 }
315 break;
316 }
317 case HloOpcode::kTrace: {
318 TF_RET_CHECK(proto.has_literal());
319 TF_ASSIGN_OR_RETURN(
320 auto literal,
321 Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
322 instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
323 break;
324 }
325 case HloOpcode::kFusion: {
326 // In the proto, fused computations are held exclusively within the
327 // HloInstructionProto and do not appear as an HloComputationProto within
328 // the HloModuleProto.
329 TF_RET_CHECK(!proto.fusion_kind().empty());
330 TF_ASSIGN_OR_RETURN(FusionKind fusion_kind,
331 StringToFusionKind(proto.fusion_kind()));
332
333 // Find the fused computation and set its fusion instruction.
334 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
335 << "Expect 1 called computation for fusion instruction but sees "
336 << proto.called_computation_ids_size();
337 const int64 fusion_id = proto.called_computation_ids(0);
338 auto* fused_computation =
339 tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
340 TF_RET_CHECK(fused_computation != nullptr)
341 << "No fusion computation with id " << fusion_id;
342 instruction =
343 CreateFusion(shape, fusion_kind, all_operands(), fused_computation);
344 break;
345 }
346 case HloOpcode::kRng:
347 instruction = CreateRng(shape, proto.distribution(), all_operands());
348 break;
349 case HloOpcode::kRngGetAndUpdateState:
350 instruction = CreateRngGetAndUpdateState(shape, proto.delta());
351 break;
352 case HloOpcode::kParameter:
353 instruction =
354 CreateParameter(proto.parameter_number(), shape, proto.name());
355 if (!proto.parameter_replication().replicated_at_leaf_buffers().empty()) {
356 instruction->set_parameter_replicated_at_leaf_buffers(
357 proto.parameter_replication().replicated_at_leaf_buffers());
358 }
359 break;
360 case HloOpcode::kGetTupleElement:
361 instruction =
362 CreateGetTupleElement(shape, operands(0), proto.tuple_index());
363 break;
364 case HloOpcode::kReducePrecision:
365 instruction = CreateReducePrecision(
366 shape, operands(0), proto.exponent_bits(), proto.mantissa_bits());
367 break;
368 case HloOpcode::kInfeed: {
369 TF_RET_CHECK(shape.IsTuple() &&
370 (ShapeUtil::TupleElementCount(shape) == 2))
371 << "Infeed should have a tuple shape with 2 operands, but has: "
372 << shape;
373 const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0);
374 instruction =
375 CreateInfeed(data_shape, operands(0), proto.infeed_config());
376 } break;
377 case HloOpcode::kOutfeed: {
378 Shape outfeed_shape(proto.outfeed_shape());
379 TF_RETURN_IF_ERROR(
380 ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape));
381 instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1),
382 proto.outfeed_config());
383 break;
384 }
385 case HloOpcode::kAllReduce: {
386 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
387 << "AllReduce should have 1 called computation but sees "
388 << proto.called_computation_ids_size();
389 TF_RET_CHECK(proto.channel_id() <= 0 || proto.all_reduce_id() <= 0)
390 << "AllReduce cannot have both channel_id() and all_reduce_id()";
391 absl::optional<int64> channel_id;
392 if (proto.channel_id() > 0) {
393 channel_id = proto.channel_id();
394 }
395 if (proto.all_reduce_id() > 0) {
396 channel_id = proto.all_reduce_id();
397 }
398 instruction = CreateAllReduce(
399 shape, all_operands(), computations(0),
400 /*replica_groups=*/
401 std::vector<ReplicaGroup>(proto.replica_groups().begin(),
402 proto.replica_groups().end()),
403 /*constrain_layout=*/proto.constrain_layout(),
404 /*channel_id=*/channel_id);
405 break;
406 }
407 case HloOpcode::kAllToAll: {
408 absl::optional<int64> channel_id;
409 if (proto.channel_id() > 0) {
410 channel_id = proto.channel_id();
411 }
412 absl::optional<int64> split_dimension;
413 if (proto.dimensions_size() > 0) {
414 TF_RET_CHECK(proto.dimensions_size() == 1)
415 << "AllToAll cannot have more than 1 dimension (split dimension)";
416 TF_RET_CHECK(all_operands().size() == 1)
417 << "AllToAll must have a single operand when the split dimension "
418 "is specified";
419 split_dimension = proto.dimensions(0);
420 }
421 instruction = CreateAllToAll(
422 shape, all_operands(),
423 /*replica_groups=*/
424 std::vector<ReplicaGroup>(proto.replica_groups().begin(),
425 proto.replica_groups().end()),
426 /*channel_id=*/channel_id, split_dimension);
427 break;
428 }
429 case HloOpcode::kCollectivePermute: {
430 std::vector<std::pair<int64, int64>> source_target_pairs(
431 proto.source_target_pairs_size());
432 absl::optional<int64> channel_id;
433 if (proto.channel_id() > 0) {
434 channel_id = proto.channel_id();
435 }
436 for (int i = 0; i < source_target_pairs.size(); i++) {
437 source_target_pairs[i].first = proto.source_target_pairs(i).source();
438 source_target_pairs[i].second = proto.source_target_pairs(i).target();
439 }
440 instruction = CreateCollectivePermute(shape, operands(0),
441 source_target_pairs, channel_id);
442 break;
443 }
444 case HloOpcode::kReplicaId: {
445 instruction = CreateReplicaId();
446 break;
447 }
448 case HloOpcode::kPartitionId: {
449 instruction = CreatePartitionId();
450 break;
451 }
452 case HloOpcode::kConvolution: {
453 TF_RET_CHECK(proto.has_window());
454 TF_RET_CHECK(proto.has_convolution_dimension_numbers());
455 PrecisionConfig precision_config = proto.precision_config();
456 precision_config.mutable_operand_precision()->Resize(
457 proto.operand_ids_size(), PrecisionConfig::DEFAULT);
458 instruction = CreateConvolve(
459 shape, operands(0), operands(1),
460 std::max<int64>(proto.feature_group_count(), 1),
461 std::max<int64>(proto.batch_group_count(), 1), proto.window(),
462 proto.convolution_dimension_numbers(), precision_config);
463 break;
464 }
465 case HloOpcode::kReduceWindow:
466 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
467 << "ReduceWindow should have 1 called computation but sees "
468 << proto.called_computation_ids_size();
469 instruction = CreateReduceWindow(shape, operands(0), operands(1),
470 proto.window(), computations(0));
471 break;
472 case HloOpcode::kSelectAndScatter:
473 TF_RET_CHECK(proto.called_computation_ids_size() == 2)
474 << "SelectAndScatter should have 2 called computations but sees "
475 << proto.called_computation_ids_size();
476 instruction = CreateSelectAndScatter(shape, operands(0), computations(0),
477 proto.window(), operands(1),
478 operands(2), computations(1));
479 break;
480 case HloOpcode::kCustomCall: {
481 if (proto.constrain_layout()) {
482 // A proto RepeatedPtrField cannot be converted to a Span (it is a
483 // vector of pointers essentially) so create a vector of shapes to pass
484 // in.
485 std::vector<Shape> operand_shapes;
486 for (const ShapeProto& shape_proto :
487 proto.operand_shapes_with_layout()) {
488 operand_shapes.emplace_back(shape_proto);
489 }
490 instruction =
491 CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
492 operand_shapes, proto.backend_config());
493 } else {
494 instruction =
495 CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
496 proto.backend_config());
497 }
498 auto custom_call_instr =
499 Cast<HloCustomCallInstruction>(instruction.get());
500 if (proto.has_window()) {
501 custom_call_instr->set_window(proto.window());
502 }
503 if (proto.has_convolution_dimension_numbers()) {
504 custom_call_instr->set_convolution_dimension_numbers(
505 proto.convolution_dimension_numbers());
506 }
507 custom_call_instr->set_feature_group_count(
508 std::max(static_cast<int64>(proto.feature_group_count()), int64{1}));
509 custom_call_instr->set_batch_group_count(
510 std::max(static_cast<int64>(proto.batch_group_count()), int64{1}));
511 custom_call_instr->set_custom_call_has_side_effect(
512 proto.custom_call_has_side_effect());
513 break;
514 }
515 case HloOpcode::kPad:
516 TF_RET_CHECK(proto.has_padding_config());
517 instruction =
518 CreatePad(shape, operands(0), operands(1), proto.padding_config());
519 break;
520 case HloOpcode::kDynamicSlice: {
521 std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size());
522 absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
523 TF_RET_CHECK(proto.operand_ids_size() >= 1)
524 << "DynamicSlice instruction should have at least 1 operands but "
525 "sees "
526 << proto.operand_ids_size();
527 // TODO(b/118437727): Old form, make the check unconditional.
528 if (proto.operand_ids_size() != 2 || operands(1)->shape().rank() != 1) {
529 auto expected_operands = 1 + operands(0)->shape().rank();
530 TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
531 << "DynamicSlice instruction should have " << expected_operands
532 << " operands, but has " << proto.operand_ids_size();
533 }
534 const auto& operand_vector = all_operands();
535 instruction = CreateDynamicSlice(
536 shape, operands(0), absl::MakeSpan(operand_vector).subspan(1),
537 slice_sizes);
538 break;
539 }
540 case HloOpcode::kDynamicUpdateSlice: {
541 TF_RET_CHECK(proto.operand_ids_size() >= 2)
542 << "DynamicUpdateSlice instruction should have at least 2 operands "
543 "but sees "
544 << proto.operand_ids_size();
545 // TODO(b/118437727): Old form, make the check unconditional.
546 if (proto.operand_ids_size() != 3 || operands(2)->shape().rank() != 1) {
547 auto expected_operands = 2 + operands(0)->shape().rank();
548 TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
549 << "DynamicUpdateSlice instruction should have "
550 << expected_operands << " operands, but has "
551 << proto.operand_ids_size();
552 }
553 const auto& operand_vector = all_operands();
554 instruction =
555 CreateDynamicUpdateSlice(shape, operands(0), operands(1),
556 absl::MakeSpan(operand_vector).subspan(2));
557
558 break;
559 }
560 case HloOpcode::kGather: {
561 TF_RET_CHECK(proto.has_gather_dimension_numbers())
562 << "Gather instruction should have GatherDimensionNumbers set.";
563 auto gather_dimension_numbers = absl::make_unique<GatherDimensionNumbers>(
564 proto.gather_dimension_numbers());
565 std::vector<int64> gather_slice_sizes;
566 for (int64 bound : proto.gather_slice_sizes()) {
567 gather_slice_sizes.push_back(bound);
568 }
569 instruction = CreateGather(shape, operands(0), operands(1),
570 *gather_dimension_numbers, gather_slice_sizes,
571 proto.indices_are_sorted());
572 break;
573 }
574 case HloOpcode::kScatter: {
575 TF_RET_CHECK(proto.has_scatter_dimension_numbers())
576 << "Scatter instruction should have ScatterDimensionNumbers set.";
577 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
578 << "Scatter instruction should have 1 called computation but sees "
579 << proto.called_computation_ids_size();
580 auto scatter_dimension_numbers =
581 absl::make_unique<ScatterDimensionNumbers>(
582 proto.scatter_dimension_numbers());
583 instruction =
584 CreateScatter(shape, operands(0), operands(1), operands(2),
585 computations(0), *scatter_dimension_numbers,
586 proto.indices_are_sorted(), proto.unique_indices());
587 break;
588 }
589 case HloOpcode::kIota:
590 TF_RET_CHECK(proto.dimensions_size() == 1)
591 << "Iota instruction should have 1 dimension but sees "
592 << proto.dimensions_size();
593 instruction = CreateIota(shape, proto.dimensions(0));
594 break;
595 case HloOpcode::kDot: {
596 TF_RET_CHECK(proto.has_dot_dimension_numbers())
597 << "Dot instruction should have dot_dimension_numbers.";
598 PrecisionConfig precision_config = proto.precision_config();
599 precision_config.mutable_operand_precision()->Resize(
600 proto.operand_ids_size(), PrecisionConfig::DEFAULT);
601 instruction = absl::make_unique<HloDotInstruction>(
602 shape, operands(0), operands(1), proto.dot_dimension_numbers(),
603 precision_config);
604 break;
605 }
606 case HloOpcode::kDomain: {
607 std::shared_ptr<const HloSharding> entry_hlo_sharding;
608 std::shared_ptr<const HloSharding> exit_hlo_sharding;
609 if (proto.has_domain_entry_sharding()) {
610 TF_ASSIGN_OR_RETURN(
611 HloSharding sharding,
612 HloSharding::FromProto(proto.domain_entry_sharding()));
613 entry_hlo_sharding = std::make_shared<const HloSharding>(sharding);
614 }
615 if (proto.has_domain_exit_sharding()) {
616 TF_ASSIGN_OR_RETURN(
617 HloSharding sharding,
618 HloSharding::FromProto(proto.domain_exit_sharding()));
619 exit_hlo_sharding = std::make_shared<const HloSharding>(sharding);
620 }
621 instruction = absl::make_unique<HloDomainInstruction>(
622 shape, operands(0),
623 absl::make_unique<ShardingMetadata>(entry_hlo_sharding),
624 absl::make_unique<ShardingMetadata>(exit_hlo_sharding));
625 break;
626 }
627 case HloOpcode::kGetDimensionSize:
628 TF_RET_CHECK(proto.dimensions_size() == 1);
629 instruction =
630 CreateGetDimensionSize(shape, operands(0), proto.dimensions(0));
631 break;
632 case HloOpcode::kSetDimensionSize:
633 TF_RET_CHECK(proto.dimensions_size() == 1);
634 instruction = CreateSetDimensionSize(shape, operands(0), operands(1),
635 proto.dimensions(0));
636 break;
637 case HloOpcode::kReshape: {
638 int64 inferred_dimension = -1;
639 if (!proto.dimensions().empty()) {
640 inferred_dimension = proto.dimensions()[0];
641 }
642 TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
643 ShapeUtil::ElementsIn(shape) ==
644 ShapeUtil::ElementsIn(operands(0)->shape()))
645 << "shape: " << ShapeUtil::HumanString(shape)
646 << " operand: " << ShapeUtil::HumanString(operands(0)->shape());
647 instruction = CreateReshape(shape, operands(0), inferred_dimension);
648 break;
649 }
650 default: {
651 instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
652 for (const int64 operand_id : proto.operand_ids()) {
653 instruction->AppendOperand(instruction_map.at(operand_id));
654 }
655 if (instruction->opcode() != HloOpcode::kFusion) {
656 if (instruction->opcode() == HloOpcode::kCall) {
657 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
658 << "Call should have 1 called computation but has "
659 << proto.called_computation_ids_size();
660 }
661 for (const int64 computation_id : proto.called_computation_ids()) {
662 instruction->called_computations_.push_back(
663 computation_map.at(computation_id));
664 }
665 }
666 TF_RET_CHECK(!proto.has_precision_config())
667 << instruction->opcode() << proto.DebugString();
668 TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
669 break;
670 }
671 }
672
673 for (const int64 predecessor_id : proto.control_predecessor_ids()) {
674 TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
675 << "No instruction with id " << predecessor_id;
676 TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
677 ->AddControlDependencyTo(instruction.get()));
678 }
679
680 TF_RET_CHECK(!proto.name().empty());
681 instruction->SetAndSanitizeName(proto.name());
682 instruction->metadata_ = proto.metadata();
683 instruction->backend_config_ = proto.backend_config();
684 instruction->outer_dimension_partitions_.assign(
685 proto.outer_dimension_partitions().begin(),
686 proto.outer_dimension_partitions().end());
687
688 TF_RET_CHECK(proto.id() >= 0)
689 << "Instruction with negative id: " << proto.id();
690 TF_RET_CHECK(proto.id() <= INT_MAX)
691 << "Instruction with id > INT_MAX: " << proto.id();
692 instruction->unique_id_ = proto.id();
693
694 if (proto.has_sharding()) {
695 TF_ASSIGN_OR_RETURN(const auto& sharding,
696 HloSharding::FromProto(proto.sharding()));
697 instruction->set_sharding(sharding);
698 }
699
700 if (proto.has_frontend_attributes()) {
701 instruction->set_frontend_attributes(proto.frontend_attributes());
702 }
703
704 return std::move(instruction);
705 }
706
CreateParameter(int64 parameter_number,const Shape & shape,const string & name)707 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
708 int64 parameter_number, const Shape& shape, const string& name) {
709 return absl::make_unique<HloParameterInstruction>(parameter_number, shape,
710 name);
711 }
712
CreateTrace(const string & tag,HloInstruction * operand)713 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
714 const string& tag, HloInstruction* operand) {
715 return absl::make_unique<HloTraceInstruction>(tag, operand);
716 }
717
CreateConstant(Literal literal)718 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
719 Literal literal) {
720 return absl::make_unique<HloConstantInstruction>(std::move(literal));
721 }
722
CreateIota(const Shape & shape,int64 iota_dimension)723 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
724 const Shape& shape, int64 iota_dimension) {
725 return absl::make_unique<HloIotaInstruction>(shape, iota_dimension);
726 }
727
728 /* static */ std::unique_ptr<HloInstruction>
CreateGetTupleElement(const Shape & shape,HloInstruction * operand,int64 index)729 HloInstruction::CreateGetTupleElement(const Shape& shape,
730 HloInstruction* operand, int64 index) {
731 return absl::make_unique<HloGetTupleElementInstruction>(shape, operand,
732 index);
733 }
734
CreateRng(const Shape & shape,RandomDistribution distribution,absl::Span<HloInstruction * const> parameters)735 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
736 const Shape& shape, RandomDistribution distribution,
737 absl::Span<HloInstruction* const> parameters) {
738 return absl::make_unique<HloRngInstruction>(shape, distribution, parameters);
739 }
740
741 /* static */ std::unique_ptr<HloInstruction>
CreateRngGetAndUpdateState(const Shape & shape,int64 delta)742 HloInstruction::CreateRngGetAndUpdateState(const Shape& shape, int64 delta) {
743 return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta);
744 }
745
CreateNary(const Shape & shape,HloOpcode opcode,absl::Span<HloInstruction * const> operands)746 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
747 const Shape& shape, HloOpcode opcode,
748 absl::Span<HloInstruction* const> operands) {
749 if (opcode == HloOpcode::kCopy) {
750 // It is impossible to copy an opaque shape, we don't know how big it is.
751 CHECK(!shape.IsOpaque());
752 }
753 auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
754 for (auto operand : operands) {
755 instruction->AppendOperand(operand);
756 }
757 return instruction;
758 }
759
CreateUnary(const Shape & shape,HloOpcode opcode,HloInstruction * operand)760 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
761 const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
762 // Only certain opcodes are supported with CreateUnary: opcodes of unary
763 // instructions with no auxiliary fields.
764 switch (opcode) {
765 case HloOpcode::kAbs:
766 case HloOpcode::kRoundNearestAfz:
767 case HloOpcode::kBitcast:
768 case HloOpcode::kCeil:
769 case HloOpcode::kCopy:
770 case HloOpcode::kCopyStart:
771 case HloOpcode::kCopyDone:
772 case HloOpcode::kCos:
773 case HloOpcode::kClz:
774 case HloOpcode::kExp:
775 case HloOpcode::kExpm1:
776 case HloOpcode::kFloor:
777 case HloOpcode::kImag:
778 case HloOpcode::kIsFinite:
779 case HloOpcode::kLog:
780 case HloOpcode::kLog1p:
781 case HloOpcode::kNot:
782 case HloOpcode::kNegate:
783 case HloOpcode::kPopulationCount:
784 case HloOpcode::kReal:
785 case HloOpcode::kRsqrt:
786 case HloOpcode::kSign:
787 case HloOpcode::kSin:
788 case HloOpcode::kSqrt:
789 case HloOpcode::kTanh:
790 break;
791 default:
792 LOG(FATAL) << "Invalid unary instruction opcode "
793 << HloOpcodeString(opcode);
794 }
795 return CreateNary(shape, opcode, {operand});
796 }
797
CreateBinary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs)798 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
799 const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
800 HloInstruction* rhs) {
801 // Only certain opcodes are supported with CreateBinary: opcodes of binary
802 // instructions with no auxiliary fields.
803 switch (opcode) {
804 case HloOpcode::kAdd:
805 case HloOpcode::kAtan2:
806 case HloOpcode::kDivide:
807 case HloOpcode::kComplex:
808 case HloOpcode::kMaximum:
809 case HloOpcode::kMinimum:
810 case HloOpcode::kMultiply:
811 case HloOpcode::kPower:
812 case HloOpcode::kRemainder:
813 case HloOpcode::kSubtract:
814 case HloOpcode::kAnd:
815 case HloOpcode::kOr:
816 case HloOpcode::kXor:
817 case HloOpcode::kShiftLeft:
818 case HloOpcode::kShiftRightArithmetic:
819 case HloOpcode::kShiftRightLogical:
820 break;
821 default:
822 LOG(FATAL) << "Invalid binary instruction opcode "
823 << HloOpcodeString(opcode);
824 }
825 return CreateNary(shape, opcode, {lhs, rhs});
826 }
827
// Factory for ternary instructions (clamp/select/tuple-select). Crashes on
// any other opcode.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
    const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
    HloInstruction* rhs, HloInstruction* ehs) {
  // Only certain opcodes are supported with CreateTernary: opcodes of ternary
  // instructions with no auxiliary fields.
  switch (opcode) {
    case HloOpcode::kClamp:
    case HloOpcode::kSelect:
    case HloOpcode::kTupleSelect:
      break;
    default:
      LOG(FATAL) << "Invalid ternary instruction opcode "
                 << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {lhs, rhs, ehs});
}
844
// Factory for variadic instructions. kTuple is currently the only opcode
// with an unbounded operand count and no auxiliary fields; crashes otherwise.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
    const Shape& shape, HloOpcode opcode,
    absl::Span<HloInstruction* const> operands) {
  CHECK_EQ(HloOpcode::kTuple, opcode);
  return CreateNary(shape, opcode, operands);
}
851
CreateMap(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * map_computation)852 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
853 const Shape& shape, absl::Span<HloInstruction* const> operands,
854 HloComputation* map_computation) {
855 return absl::make_unique<HloMapInstruction>(shape, operands, map_computation);
856 }
857
CreateConvolve(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,int64 feature_group_count,int64 batch_group_count,const Window & window,const ConvolutionDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)858 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
859 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
860 int64 feature_group_count, int64 batch_group_count, const Window& window,
861 const ConvolutionDimensionNumbers& dimension_numbers,
862 const PrecisionConfig& precision_config) {
863 return absl::make_unique<HloConvolutionInstruction>(
864 shape, lhs, rhs, feature_group_count, batch_group_count, window,
865 dimension_numbers, precision_config);
866 }
867
CreateFft(const Shape & shape,HloInstruction * operand,FftType fft_type,absl::Span<const int64> fft_length)868 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
869 const Shape& shape, HloInstruction* operand, FftType fft_type,
870 absl::Span<const int64> fft_length) {
871 return absl::make_unique<HloFftInstruction>(shape, operand, fft_type,
872 fft_length);
873 }
874
CreateCompare(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,ComparisonDirection direction)875 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
876 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
877 ComparisonDirection direction) {
878 return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction);
879 }
880
881 /* static */ std::unique_ptr<HloInstruction>
CreateTriangularSolve(const Shape & shape,HloInstruction * a,HloInstruction * b,const TriangularSolveOptions & options)882 HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
883 HloInstruction* b,
884 const TriangularSolveOptions& options) {
885 return absl::make_unique<HloTriangularSolveInstruction>(shape, a, b, options);
886 }
887
CreateCholesky(const Shape & shape,HloInstruction * a,const CholeskyOptions & options)888 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCholesky(
889 const Shape& shape, HloInstruction* a, const CholeskyOptions& options) {
890 return absl::make_unique<HloCholeskyInstruction>(shape, a, options);
891 }
892
CreateDot(const Shape & shape,HloInstruction * lhs,HloInstruction * rhs,const DotDimensionNumbers & dimension_numbers,const PrecisionConfig & precision_config)893 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
894 const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
895 const DotDimensionNumbers& dimension_numbers,
896 const PrecisionConfig& precision_config) {
897 return absl::make_unique<HloDotInstruction>(
898 shape, lhs, rhs, dimension_numbers, precision_config);
899 }
900
901 /* static */ std::unique_ptr<HloInstruction>
CreateReducePrecision(const Shape & shape,HloInstruction * operand,const int exponent_bits,const int mantissa_bits)902 HloInstruction::CreateReducePrecision(const Shape& shape,
903 HloInstruction* operand,
904 const int exponent_bits,
905 const int mantissa_bits) {
906 return absl::make_unique<HloReducePrecisionInstruction>(
907 shape, operand, exponent_bits, mantissa_bits);
908 }
909
CreateAllReduce(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * reduce_computation,const std::vector<ReplicaGroup> & replica_groups,bool constrain_layout,const absl::optional<int64> & channel_id)910 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllReduce(
911 const Shape& shape, absl::Span<HloInstruction* const> operands,
912 HloComputation* reduce_computation,
913 const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
914 const absl::optional<int64>& channel_id) {
915 return absl::make_unique<HloAllReduceInstruction>(
916 shape, operands, reduce_computation, replica_groups, constrain_layout,
917 channel_id);
918 }
919
CreateAllToAll(const Shape & shape,absl::Span<HloInstruction * const> operands,const std::vector<ReplicaGroup> & replica_groups,const absl::optional<int64> & channel_id,const absl::optional<int64> & split_dimension)920 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
921 const Shape& shape, absl::Span<HloInstruction* const> operands,
922 const std::vector<ReplicaGroup>& replica_groups,
923 const absl::optional<int64>& channel_id,
924 const absl::optional<int64>& split_dimension) {
925 return absl::make_unique<HloAllToAllInstruction>(
926 shape, operands, replica_groups, channel_id, split_dimension);
927 }
928
929 /* static */ std::unique_ptr<HloInstruction>
CreateCollectivePermute(const Shape & shape,HloInstruction * operand,const std::vector<std::pair<int64,int64>> & source_target_pairs,const absl::optional<int64> & channel_id)930 HloInstruction::CreateCollectivePermute(
931 const Shape& shape, HloInstruction* operand,
932 const std::vector<std::pair<int64, int64>>& source_target_pairs,
933 const absl::optional<int64>& channel_id) {
934 return absl::make_unique<HloCollectivePermuteInstruction>(
935 shape, operand, source_target_pairs, channel_id);
936 }
937
CreateReplicaId()938 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId() {
939 return absl::WrapUnique(
940 new HloInstruction(HloOpcode::kReplicaId, ShapeUtil::MakeShape(U32, {})));
941 }
942
943 /* static */ std::unique_ptr<HloInstruction>
CreatePartitionId()944 HloInstruction::CreatePartitionId() {
945 return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId,
946 ShapeUtil::MakeShape(U32, {})));
947 }
948
CreateInfeed(const Shape & infeed_shape,HloInstruction * token_operand,const string & config)949 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
950 const Shape& infeed_shape, HloInstruction* token_operand,
951 const string& config) {
952 return absl::make_unique<HloInfeedInstruction>(infeed_shape, token_operand,
953 config);
954 }
955
CreateOutfeed(const Shape & outfeed_shape,HloInstruction * operand,HloInstruction * token_operand,absl::string_view outfeed_config)956 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
957 const Shape& outfeed_shape, HloInstruction* operand,
958 HloInstruction* token_operand, absl::string_view outfeed_config) {
959 return absl::make_unique<HloOutfeedInstruction>(
960 outfeed_shape, operand, token_operand, outfeed_config);
961 }
962
CreateSend(HloInstruction * operand,HloInstruction * token,int64 channel_id,bool is_host_transfer)963 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
964 HloInstruction* operand, HloInstruction* token, int64 channel_id,
965 bool is_host_transfer) {
966 return absl::make_unique<HloSendInstruction>(operand, token, channel_id,
967 is_host_transfer);
968 }
969
CreateSendDone(HloInstruction * operand,bool is_host_transfer)970 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
971 HloInstruction* operand, bool is_host_transfer) {
972 auto send_operand = DynCast<HloSendInstruction>(operand);
973 CHECK(send_operand != nullptr)
974 << "SendDone must take the context operand from Send";
975 return absl::make_unique<HloSendDoneInstruction>(send_operand,
976 is_host_transfer);
977 }
978
CreateRecv(const Shape & shape,HloInstruction * token,int64 channel_id,bool is_host_transfer)979 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
980 const Shape& shape, HloInstruction* token, int64 channel_id,
981 bool is_host_transfer) {
982 return absl::make_unique<HloRecvInstruction>(shape, token, channel_id,
983 is_host_transfer);
984 }
985
CreateRecvDone(HloInstruction * operand,bool is_host_transfer)986 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
987 HloInstruction* operand, bool is_host_transfer) {
988 auto recv_operand = DynCast<HloRecvInstruction>(operand);
989 CHECK(recv_operand != nullptr)
990 << "RecvDone must take the context operand from Recv";
991 return absl::make_unique<HloRecvDoneInstruction>(recv_operand,
992 is_host_transfer);
993 }
994
CreateReverse(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)995 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
996 const Shape& shape, HloInstruction* operand,
997 absl::Span<const int64> dimensions) {
998 return absl::make_unique<HloReverseInstruction>(shape, operand, dimensions);
999 }
1000
CreateAfterAll(absl::Span<HloInstruction * const> operands)1001 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
1002 absl::Span<HloInstruction* const> operands) {
1003 CHECK(!operands.empty());
1004 auto instruction = absl::WrapUnique(
1005 new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
1006 for (auto operand : operands) {
1007 instruction->AppendOperand(operand);
1008 }
1009 return instruction;
1010 }
1011
CreateToken()1012 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
1013 return absl::WrapUnique(
1014 new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
1015 }
1016
1017 /* static */ std::unique_ptr<HloInstruction>
CreateAddDependency(HloInstruction * data_operand,HloInstruction * token_operand)1018 HloInstruction::CreateAddDependency(HloInstruction* data_operand,
1019 HloInstruction* token_operand) {
1020 auto instruction = absl::WrapUnique(
1021 new HloInstruction(HloOpcode::kAddDependency, data_operand->shape()));
1022 instruction->AppendOperand(data_operand);
1023 instruction->AppendOperand(token_operand);
1024 return instruction;
1025 }
1026
CreateWhile(const Shape & shape,HloComputation * condition,HloComputation * body,HloInstruction * init)1027 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
1028 const Shape& shape, HloComputation* condition, HloComputation* body,
1029 HloInstruction* init) {
1030 auto instruction =
1031 absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
1032 instruction->AppendOperand(init);
1033 // Body comes before condition computation in the vector.
1034 instruction->called_computations_.push_back(body);
1035 instruction->called_computations_.push_back(condition);
1036 return instruction;
1037 }
1038
CreateConditional(const Shape & shape,HloInstruction * pred,HloInstruction * true_computation_arg,HloComputation * true_computation,HloInstruction * false_computation_arg,HloComputation * false_computation)1039 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
1040 const Shape& shape, HloInstruction* pred,
1041 HloInstruction* true_computation_arg, HloComputation* true_computation,
1042 HloInstruction* false_computation_arg, HloComputation* false_computation) {
1043 auto instruction =
1044 absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
1045 instruction->AppendOperand(pred);
1046 instruction->AppendOperand(true_computation_arg);
1047 instruction->AppendOperand(false_computation_arg);
1048 // In called_computations_, the index of true_computation must be 0 and that
1049 // of false computation must be 1, as defined by kTrueComputationIndex and
1050 // kFalseComputationIndex.
1051 instruction->called_computations_.push_back(true_computation);
1052 instruction->called_computations_.push_back(false_computation);
1053 return instruction;
1054 }
1055
CreateConditional(const Shape & shape,HloInstruction * branch_index,absl::Span<HloComputation * const> branch_computations,absl::Span<HloInstruction * const> branch_computation_args)1056 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
1057 const Shape& shape, HloInstruction* branch_index,
1058 absl::Span<HloComputation* const> branch_computations,
1059 absl::Span<HloInstruction* const> branch_computation_args) {
1060 auto instruction =
1061 absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
1062 instruction->AppendOperand(branch_index);
1063 CHECK_EQ(branch_computations.size(), branch_computation_args.size());
1064 for (int i = 0; i < branch_computations.size(); ++i) {
1065 instruction->called_computations_.push_back(branch_computations[i]);
1066 instruction->AppendOperand(branch_computation_args[i]);
1067 }
1068 return instruction;
1069 }
1070
CreateSlice(const Shape & shape,HloInstruction * operand,absl::Span<const int64> start_indices,absl::Span<const int64> limit_indices,absl::Span<const int64> strides)1071 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
1072 const Shape& shape, HloInstruction* operand,
1073 absl::Span<const int64> start_indices,
1074 absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
1075 return absl::make_unique<HloSliceInstruction>(shape, operand, start_indices,
1076 limit_indices, strides);
1077 }
1078
CreateDynamicSlice(const Shape & shape,HloInstruction * operand,absl::Span<HloInstruction * const> start_indices,absl::Span<const int64> slice_sizes)1079 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
1080 const Shape& shape, HloInstruction* operand,
1081 absl::Span<HloInstruction* const> start_indices,
1082 absl::Span<const int64> slice_sizes) {
1083 return absl::make_unique<HloDynamicSliceInstruction>(
1084 shape, operand, start_indices, slice_sizes);
1085 }
1086
1087 /* static */ std::unique_ptr<HloInstruction>
CreateDynamicUpdateSlice(const Shape & shape,HloInstruction * operand,HloInstruction * update,absl::Span<HloInstruction * const> start_indices)1088 HloInstruction::CreateDynamicUpdateSlice(
1089 const Shape& shape, HloInstruction* operand, HloInstruction* update,
1090 absl::Span<HloInstruction* const> start_indices) {
1091 return absl::make_unique<HloDynamicUpdateSliceInstruction>(
1092 shape, operand, update, start_indices);
1093 }
1094
CreateConcatenate(const Shape & shape,absl::Span<HloInstruction * const> operands,int64 dimension)1095 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
1096 const Shape& shape, absl::Span<HloInstruction* const> operands,
1097 int64 dimension) {
1098 return absl::make_unique<HloConcatenateInstruction>(shape, operands,
1099 dimension);
1100 }
1101
CreateConvert(const Shape & shape,HloInstruction * operand)1102 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
1103 const Shape& shape, HloInstruction* operand) {
1104 auto instruction =
1105 absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
1106 instruction->AppendOperand(operand);
1107 return instruction;
1108 }
1109
1110 /* static */ std::unique_ptr<HloInstruction>
CreateBitcastConvert(const Shape & shape,HloInstruction * operand)1111 HloInstruction::CreateBitcastConvert(const Shape& shape,
1112 HloInstruction* operand) {
1113 auto instruction =
1114 absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
1115 instruction->AppendOperand(operand);
1116 return instruction;
1117 }
1118
CreateBitcast(const Shape & shape,HloInstruction * operand)1119 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBitcast(
1120 const Shape& shape, HloInstruction* operand) {
1121 auto instruction =
1122 absl::WrapUnique(new HloInstruction(HloOpcode::kBitcast, shape));
1123 instruction->AppendOperand(operand);
1124 return instruction;
1125 }
1126
CreateReduce(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)1127 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1128 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1129 absl::Span<const int64> dimensions_to_reduce,
1130 HloComputation* reduce_computation) {
1131 auto instruction = absl::WrapUnique(new HloReduceInstruction(
1132 shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
1133 return std::move(instruction);
1134 }
1135
CreateReduce(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)1136 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1137 const Shape& shape, absl::Span<HloInstruction* const> operands,
1138 absl::Span<HloInstruction* const> init_values,
1139 absl::Span<const int64> dimensions_to_reduce,
1140 HloComputation* reduce_computation) {
1141 std::vector<HloInstruction*> all_args;
1142 all_args.reserve(operands.size() * 2);
1143 all_args.insert(all_args.end(), operands.begin(), operands.end());
1144 all_args.insert(all_args.end(), init_values.begin(), init_values.end());
1145 return absl::make_unique<HloReduceInstruction>(
1146 shape, all_args, dimensions_to_reduce, reduce_computation);
1147 }
1148
CreateReduceWindow(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,const Window & window,HloComputation * reduce_computation)1149 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
1150 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1151 const Window& window, HloComputation* reduce_computation) {
1152 return absl::make_unique<HloReduceWindowInstruction>(
1153 shape, operand, init_value, window, reduce_computation);
1154 }
1155
1156 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormTraining(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,float epsilon,int64 feature_index)1157 HloInstruction::CreateBatchNormTraining(const Shape& shape,
1158 HloInstruction* operand,
1159 HloInstruction* scale,
1160 HloInstruction* offset, float epsilon,
1161 int64 feature_index) {
1162 return absl::make_unique<HloBatchNormTrainingInstruction>(
1163 shape, operand, scale, offset, epsilon, feature_index);
1164 }
1165
1166 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormInference(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * offset,HloInstruction * mean,HloInstruction * variance,float epsilon,int64 feature_index)1167 HloInstruction::CreateBatchNormInference(
1168 const Shape& shape, HloInstruction* operand, HloInstruction* scale,
1169 HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
1170 float epsilon, int64 feature_index) {
1171 return absl::make_unique<HloBatchNormInferenceInstruction>(
1172 shape, operand, scale, offset, mean, variance, epsilon, feature_index);
1173 }
1174
1175 /* static */ std::unique_ptr<HloInstruction>
CreateBatchNormGrad(const Shape & shape,HloInstruction * operand,HloInstruction * scale,HloInstruction * mean,HloInstruction * variance,HloInstruction * grad_output,float epsilon,int64 feature_index)1176 HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
1177 HloInstruction* scale, HloInstruction* mean,
1178 HloInstruction* variance,
1179 HloInstruction* grad_output, float epsilon,
1180 int64 feature_index) {
1181 return absl::make_unique<HloBatchNormGradInstruction>(
1182 shape, operand, scale, mean, variance, grad_output, epsilon,
1183 feature_index);
1184 }
1185
1186 /* static */ std::unique_ptr<HloInstruction>
CreateSelectAndScatter(const Shape & shape,HloInstruction * operand,HloComputation * select,const Window & window,HloInstruction * source,HloInstruction * init_value,HloComputation * scatter)1187 HloInstruction::CreateSelectAndScatter(
1188 const Shape& shape, HloInstruction* operand, HloComputation* select,
1189 const Window& window, HloInstruction* source, HloInstruction* init_value,
1190 HloComputation* scatter) {
1191 return absl::make_unique<HloSelectAndScatterInstruction>(
1192 shape, operand, select, window, source, init_value, scatter);
1193 }
1194
CreateBroadcast(const Shape & shape,HloInstruction * operand,absl::Span<const int64> broadcast_dimensions)1195 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
1196 const Shape& shape, HloInstruction* operand,
1197 absl::Span<const int64> broadcast_dimensions) {
1198 return absl::make_unique<HloBroadcastInstruction>(shape, operand,
1199 broadcast_dimensions);
1200 }
1201
1202 /* static */ std::unique_ptr<HloInstruction>
CreateGetDimensionSize(const Shape & shape,HloInstruction * operand,int64 dimension)1203 HloInstruction::CreateGetDimensionSize(const Shape& shape,
1204 HloInstruction* operand,
1205 int64 dimension) {
1206 return absl::make_unique<HloGetDimensionSizeInstruction>(shape, operand,
1207 dimension);
1208 }
1209
1210 /* static */ std::unique_ptr<HloInstruction>
CreateSetDimensionSize(const Shape & shape,HloInstruction * operand,HloInstruction * val,int64 dimension)1211 HloInstruction::CreateSetDimensionSize(const Shape& shape,
1212 HloInstruction* operand,
1213 HloInstruction* val, int64 dimension) {
1214 return absl::make_unique<HloSetDimensionSizeInstruction>(shape, operand, val,
1215 dimension);
1216 }
1217
1218 /* static */ std::unique_ptr<HloInstruction>
CreateBroadcastSequence(const Shape & output_shape,HloInstruction * operand,const std::function<HloInstruction * (std::unique_ptr<HloInstruction>)> & adder)1219 HloInstruction::CreateBroadcastSequence(
1220 const Shape& output_shape, HloInstruction* operand,
1221 const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
1222 adder) {
1223 CHECK(ShapeUtil::IsScalar(operand->shape()) ||
1224 operand->shape().rank() == output_shape.rank());
1225 Shape broadcast_shape = ShapeUtil::ChangeElementType(
1226 output_shape, operand->shape().element_type());
1227 // Do explicit broadcast for scalar.
1228 if (ShapeUtil::IsScalar(operand->shape())) {
1229 auto broadcast =
1230 HloInstruction::CreateBroadcast(broadcast_shape, operand, {});
1231 broadcast->set_metadata(operand->metadata());
1232 if (operand->has_sharding()) {
1233 broadcast->set_sharding(operand->sharding());
1234 }
1235 broadcast->set_frontend_attributes(operand->frontend_attributes());
1236 return broadcast;
1237 }
1238 // Do explicit broadcast for degenerate broadcast.
1239 std::vector<int64> broadcast_dimensions;
1240 std::vector<int64> reshaped_dimensions;
1241 for (int i = 0; i < operand->shape().rank(); i++) {
1242 if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
1243 broadcast_dimensions.push_back(i);
1244 reshaped_dimensions.push_back(operand->shape().dimensions(i));
1245 } else {
1246 CHECK_EQ(operand->shape().dimensions(i), 1)
1247 << "An explicit broadcast sequence requires the broadcasted "
1248 "dimensions to be trivial; operand: "
1249 << operand->ToString() << "; output_shape: " << output_shape;
1250 }
1251 }
1252 // Eliminate the size one dimensions.
1253 HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape(
1254 ShapeUtil::MakeShape(operand->shape().element_type(),
1255 reshaped_dimensions),
1256 operand));
1257 reshaped_operand->set_metadata(operand->metadata());
1258 if (operand->has_sharding()) {
1259 reshaped_operand->set_sharding(operand->sharding());
1260 }
1261 reshaped_operand->set_frontend_attributes(operand->frontend_attributes());
1262 // Broadcast 'reshape' up to the larger size.
1263 auto broadcast = HloInstruction::CreateBroadcast(
1264 broadcast_shape, reshaped_operand, broadcast_dimensions);
1265 broadcast->set_metadata(operand->metadata());
1266 if (operand->has_sharding()) {
1267 broadcast->set_sharding(operand->sharding());
1268 }
1269 broadcast->set_frontend_attributes(operand->frontend_attributes());
1270 return broadcast;
1271 }
1272
CreatePad(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)1273 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
1274 const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
1275 const PaddingConfig& padding_config) {
1276 return absl::make_unique<HloPadInstruction>(shape, operand, padding_value,
1277 padding_config);
1278 }
1279
CreateReshape(const Shape & shape,HloInstruction * operand,int64 inferred_dimension)1280 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
1281 const Shape& shape, HloInstruction* operand, int64 inferred_dimension) {
1282 CHECK_EQ(ShapeUtil::ElementsIn(shape),
1283 ShapeUtil::ElementsIn(operand->shape()))
1284 << "shape: " << ShapeUtil::HumanString(shape)
1285 << " operand: " << ShapeUtil::HumanString(operand->shape());
1286
1287 return absl::make_unique<HloReshapeInstruction>(shape, operand,
1288 inferred_dimension);
1289 }
1290
CreateTranspose(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)1291 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
1292 const Shape& shape, HloInstruction* operand,
1293 absl::Span<const int64> dimensions) {
1294 return absl::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
1295 }
1296
CreateSort(const Shape & shape,int64 dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)1297 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
1298 const Shape& shape, int64 dimension,
1299 absl::Span<HloInstruction* const> operands, HloComputation* compare,
1300 bool is_stable) {
1301 return absl::make_unique<HloSortInstruction>(shape, dimension, operands,
1302 compare, is_stable);
1303 }
1304
CreateFusion(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1305 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1306 const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
1307 return absl::make_unique<HloFusionInstruction>(shape, fusion_kind,
1308 fused_root);
1309 }
1310
CreateFusion(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation)1311 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1312 const Shape& shape, FusionKind fusion_kind,
1313 absl::Span<HloInstruction* const> operands,
1314 HloComputation* fusion_computation) {
1315 return absl::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
1316 fusion_computation);
1317 }
1318
set_single_sharding(const HloSharding & sharding)1319 void HloInstruction::set_single_sharding(const HloSharding& sharding) {
1320 CHECK(!sharding.IsTuple()) << sharding;
1321 if (shape().IsTuple()) {
1322 set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape())));
1323 } else {
1324 set_sharding(sharding);
1325 }
1326 }
1327
// Copies metadata, frontend attributes and (when safe) sharding from this
// instruction onto `derived_instruction`; any pre-existing sharding on the
// derived instruction is cleared when the shapes are not compatible.
void HloInstruction::SetupDerivedInstruction(
    HloInstruction* derived_instruction) const {
  if (sharding_ != nullptr && ShapeUtil::CompatibleIgnoringElementType(
                                  shape_, derived_instruction->shape())) {
    // Only copy sharding if the shape of the two instruction is compatible
    // because copying it between differently shaped instructions can produce
    // invalid shardings.
    derived_instruction->set_sharding(*sharding_);
  } else {
    derived_instruction->clear_sharding();
  }
  derived_instruction->set_metadata(metadata_);
  derived_instruction->set_frontend_attributes(frontend_attributes_);
}
1342
// Returns true if this instruction itself (ignoring called computations)
// has a side effect the optimizer must preserve.
bool HloInstruction::HasSideEffectNoRecurse() const {
  switch (opcode_) {
    // Communication, RNG-state mutation, I/O and tracing always have
    // observable effects.
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kRng:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kTrace:
      return true;
    // A cross-module (channel-bearing) or layout-constrained all-reduce may
    // not be removed or reordered.
    case HloOpcode::kAllReduce:
      return channel_id().has_value() ||
             Cast<HloAllReduceInstruction>(this)->constrain_layout();
    // Custom calls declare their own side-effect status.
    case HloOpcode::kCustomCall:
      return Cast<HloCustomCallInstruction>(this)
          ->custom_call_has_side_effect();
    default:
      return false;
  }
}
1365
HasSideEffect() const1366 bool HloInstruction::HasSideEffect() const {
1367 if (HasSideEffectNoRecurse()) {
1368 return true;
1369 }
1370 // Check if any of the called computations has a side effect.
1371 for (const auto& computation : called_computations()) {
1372 if (computation->HasSideEffect()) {
1373 return true;
1374 }
1375 }
1376 return false;
1377 }
1378
CreateCall(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * computation)1379 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
1380 const Shape& shape, absl::Span<HloInstruction* const> operands,
1381 HloComputation* computation) {
1382 std::unique_ptr<HloInstruction> instruction =
1383 absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
1384 for (auto operand : operands) {
1385 instruction->AppendOperand(operand);
1386 }
1387 instruction->called_computations_.push_back(computation);
1388 return instruction;
1389 }
1390
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque)1391 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1392 const Shape& shape, absl::Span<HloInstruction* const> operands,
1393 absl::string_view custom_call_target, string opaque) {
1394 return absl::make_unique<HloCustomCallInstruction>(
1395 shape, operands, custom_call_target, std::move(opaque));
1396 }
1397
// Overload of CreateCustomCall that additionally supplies
// `operand_shapes_with_layout`, which the HloCustomCallInstruction subclass
// uses to constrain operand layouts (see hlo_instructions.h).
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    absl::string_view custom_call_target,
    absl::Span<const Shape> operand_shapes_with_layout, string opaque) {
  return absl::make_unique<HloCustomCallInstruction>(
      shape, operands, custom_call_target, std::move(opaque),
      operand_shapes_with_layout);
}
1406
CreateTuple(absl::Span<HloInstruction * const> elements)1407 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
1408 absl::Span<HloInstruction* const> elements) {
1409 std::vector<Shape> element_shapes;
1410 for (auto element : elements) {
1411 element_shapes.push_back(element->shape());
1412 }
1413 Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);
1414 return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
1415 }
1416
// Creates a gather instruction reading slices of `operand` at positions
// given by `start_indices`.  `indices_are_sorted` is a hint forwarded to the
// subclass (presumably enabling backend optimizations -- confirm in
// hlo_instructions.h).
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
    const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
    const GatherDimensionNumbers& gather_dim_numbers,
    absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
  return absl::make_unique<HloGatherInstruction>(
      shape, operand, start_indices, gather_dim_numbers, slice_sizes,
      indices_are_sorted);
}
1425
// Creates a scatter instruction that applies `updates` into `operand` at
// `scatter_indices`, combining colliding values with `update_computation`.
// `indices_are_sorted`/`unique_indices` are hints forwarded to the subclass.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
    const Shape& shape, HloInstruction* operand,
    HloInstruction* scatter_indices, HloInstruction* updates,
    HloComputation* update_computation,
    const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
    bool unique_indices) {
  return absl::make_unique<HloScatterInstruction>(
      shape, operand, scatter_indices, updates, update_computation,
      scatter_dim_numbers, indices_are_sorted, unique_indices);
}
1436
// Creates a kDomain instruction wrapping `operand`, taking ownership of the
// metadata describing the operand-side and user-side domains.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
    const Shape& shape, HloInstruction* operand,
    std::unique_ptr<DomainMetadata> operand_side_metadata,
    std::unique_ptr<DomainMetadata> user_side_metadata) {
  return absl::make_unique<HloDomainInstruction>(
      shape, operand, std::move(operand_side_metadata),
      std::move(user_side_metadata));
}
1445
// Creates a clone of this instruction with result shape `shape` and operands
// `new_operands`.  The clone is detached from any computation (callers add
// it).  If `context` is non-null, the old->new mapping is recorded and any
// called computation living in a different module is deep-cloned into the
// context's module.
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
    const Shape& shape, absl::Span<HloInstruction* const> new_operands,
    HloCloneContext* context) const {
  VLOG(3) << "CloneWithNewOperands:\n  " << ToString();
  VLOG(3) << "  new operands:";
  for (const HloInstruction* new_operand : new_operands) {
    VLOG(3) << "    %" << new_operand->name();
  }

  std::unique_ptr<HloInstruction> clone;
  // Explicitly call the factory for the instruction type. This is more robust
  // in the face of code changes than copying fields explicitly. This also
  // properly sets the user fields of the operands.
  switch (opcode_) {
    // Ops migrated to subclasses.
    // TODO(b/80131774): Remove this switch when migration is complete.
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kFft:
    case HloOpcode::kCompare:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReverse:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReduce:
    case HloOpcode::kTranspose:
    case HloOpcode::kBroadcast:
    case HloOpcode::kReshape:
    case HloOpcode::kMap:
    case HloOpcode::kSlice:
    case HloOpcode::kConstant:
    case HloOpcode::kTrace:
    case HloOpcode::kFusion:
    case HloOpcode::kRng:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kParameter:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kConvolution:
    case HloOpcode::kCustomCall:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kPad:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kSort:
    case HloOpcode::kGather:
    case HloOpcode::kScatter:
    case HloOpcode::kIota:
    case HloOpcode::kDot:
    case HloOpcode::kDomain:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
      // Subclasses carry their own attributes; delegate to their virtual
      // clone implementation.
      clone = CloneWithNewOperandsImpl(shape, new_operands, context);
      break;
    // Unary ops.
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kBitcast:
    case HloOpcode::kCeil:
    case HloOpcode::kClz:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kFloor:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kTanh:
      CHECK_EQ(new_operands.size(), 1);
      clone = CreateUnary(shape, opcode_, new_operands[0]);
      break;
    // Binary ops.
    case HloOpcode::kAdd:
    case HloOpcode::kAtan2:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
    case HloOpcode::kMultiply:
    case HloOpcode::kSubtract:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      CHECK_EQ(new_operands.size(), 2);
      clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
      break;
    // Ternary ops.
    case HloOpcode::kClamp:
    case HloOpcode::kSelect:
    case HloOpcode::kTupleSelect:
      CHECK_EQ(new_operands.size(), 3);
      clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
                            new_operands[2]);
      break;
    // Other supported ops.
    case HloOpcode::kCall:
      clone = CreateCall(shape, new_operands, to_apply());
      break;
    case HloOpcode::kConvert:
      CHECK_EQ(new_operands.size(), 1);
      clone = CreateConvert(shape, new_operands[0]);
      break;
    case HloOpcode::kBitcastConvert:
      CHECK_EQ(new_operands.size(), 1);
      clone = CreateBitcastConvert(shape, new_operands[0]);
      break;
    case HloOpcode::kDynamicUpdateSlice:
      // Operand 0 is the target, operand 1 the update, the rest are start
      // indices.
      clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
                                       new_operands.subspan(2));
      break;
    case HloOpcode::kTuple:
      clone = CreateTuple(new_operands);
      // CreateTuple derives a shape from the elements; force the requested
      // shape instead (e.g. to preserve layout).
      *clone->mutable_shape() = shape;
      break;
    case HloOpcode::kWhile:
      CHECK_EQ(new_operands.size(), 1);
      clone =
          CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
      break;
    case HloOpcode::kConditional:
      // Operand 0 is the branch selector; the rest are per-branch arguments.
      CHECK_EQ(new_operands.size(), branch_count() + 1);
      clone = CreateConditional(shape, new_operands[0],
                                absl::MakeSpan(branch_computations()),
                                new_operands.subspan(1));
      break;
    case HloOpcode::kAfterAll:
      if (new_operands.empty()) {
        clone = CreateToken();
      } else {
        clone = CreateAfterAll(new_operands);
      }
      break;
    case HloOpcode::kAddDependency:
      CHECK_EQ(new_operands.size(), 2);
      clone = CreateAddDependency(new_operands[0], new_operands[1]);
      break;
    case HloOpcode::kReplicaId:
      CHECK_EQ(new_operands.size(), 0);
      clone = CreateReplicaId();
      *clone->mutable_shape() = shape;
      break;
    case HloOpcode::kPartitionId:
      CHECK_EQ(new_operands.size(), 0);
      clone = CreatePartitionId();
      *clone->mutable_shape() = shape;
      break;
  }
  // SetupDerivedInstruction will setup the precision_config_ field.
  SetupDerivedInstruction(clone.get());
  clone->set_parent(parent_);
  clone->set_outer_dimension_partitions(outer_dimension_partitions_);
  clone->set_raw_backend_config_string(backend_config_);
  if (context != nullptr) {
    context->MapInstruction(this, clone.get());
    // Re-point called computations at clones living in the context's module,
    // deep-cloning them on demand.
    clone->ReplaceCalledComputations([&](HloComputation* callee) {
      return callee->parent() != context->module()
                 ? context->module()->DeepCloneComputation(callee, context)
                 : callee;
    });
  }
  return clone;
}
1635
// Destructor: unlinks this instruction from the operand/user graph.  Operand
// slots of surviving users are set to nullptr rather than erased (see the
// matching null check in mutable_operand).
HloInstruction::~HloInstruction() {
  // Detach from operands. An instruction may be repeated as an operand. To
  // avoid calling RemoveUser twice on the same operand, check before remove.
  for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
    HloInstruction* operand = operands_[operand_num];
    if (operand == nullptr) {
      continue;
    }
    if (operand->user_map_.find(this) != operand->user_map_.end()) {
      operand->RemoveUser(this);
    }
    operands_[operand_num] = nullptr;
  }

  // Update users. Set `nullptr` to the corresponding operand slot for users.
  for (auto& user : this->users()) {
    for (int i = 0; i < user->operand_count(); ++i) {
      if (user->operands_[i] == this) {
        user->operands_[i] = nullptr;
      }
    }
  }
}
1659
// Clones this instruction with the same operands, naming the clone after
// this instruction plus `suffix`.  An empty suffix copies the name verbatim;
// otherwise ".suffix" is appended, with a numeric counter used to avoid
// stacking repeated suffixes (details in the comments below).
std::unique_ptr<HloInstruction> HloInstruction::Clone(
    const string& suffix, HloCloneContext* context) const {
  std::unique_ptr<HloInstruction> clone =
      CloneWithNewOperands(shape_, operands_, context);
  if (suffix.empty()) {
    clone->name_ = name();
  } else {
    // If an instruction is cloned multiple times avoid names like
    // foo.suffix.suffix.suffix. Instead of repeating the suffix add a numeric
    // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the
    // clone of foo.suffix2 is named foo.suffix3 and so on.
    const string dot_suffix = "." + suffix;
    size_t index = name().rfind(dot_suffix);
    if (index == string::npos) {
      // Existing name does not include ".suffix".
      clone->name_ = name() + dot_suffix;
    } else {
      // Existing name includes ".suffix". Determine if substring after
      // ".suffix" is numeric and should be replaced with an incremented number.
      string after_suffix = name().substr(index + dot_suffix.size());
      if (after_suffix.empty()) {
        // Existing name ends in ".suffix". New name should end in ".suffix2".
        clone->name_ = name() + "2";
      } else {
        // If names ends with .suffix[0-9]+ then replace with a suffix with the
        // numeric value incremented.
        int64 numeric_suffix;
        if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
          clone->name_ =
              StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
        } else {
          // Substring after ".suffix" is non-numeric.
          clone->name_ = name() + dot_suffix;
        }
      }
    }
  }
  return clone;
}
1699
// Walks up through a chain of kGetTupleElement instructions and returns the
// first non-GTE ancestor together with the ShapeIndex (outermost first) that
// the chain selects within that ancestor's shape.
std::pair<const HloInstruction*, ShapeIndex>
HloInstruction::LatestNonGteAncestorAndIndex() const {
  const HloInstruction* hlo = this;
  ShapeIndex index;
  while (hlo->opcode() == HloOpcode::kGetTupleElement) {
    index.push_back(hlo->tuple_index());
    hlo = hlo->operand(0);
  }

  // We built up index in the reverse order from what we want.
  std::reverse(index.begin(), index.end());

  return {hlo, index};
}
1714
LatestNonGteAncestor() const1715 const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
1716 const HloInstruction* hlo = this;
1717 while (hlo->opcode() == HloOpcode::kGetTupleElement) {
1718 hlo = hlo->operand(0);
1719 }
1720 return hlo;
1721 }
1722
// Returns the i-th operand; bounds-checked via vector::at (throws on an
// out-of-range index).
const HloInstruction* HloInstruction::operand(int64 i) const {
  return operands_.at(i);
}
1726
// Mutable accessor for the i-th operand.  Operand slots can be nulled out
// during teardown (see ~HloInstruction); reading such a slot is a bug, so
// CHECK first.
HloInstruction* HloInstruction::mutable_operand(int64 i) {
  CHECK(operands_[i] != nullptr);
  return operands_.at(i);
}
1731
operand_index(const HloInstruction * target) const1732 int64 HloInstruction::operand_index(const HloInstruction* target) const {
1733 for (int64 i = 0; i < operand_count(); ++i) {
1734 if (target == operand(i)) {
1735 return i;
1736 }
1737 }
1738 LOG(FATAL) << "target was not an operand: " << target->ToString();
1739 }
1740
unique_operands() const1741 HloInstruction::InstructionVector HloInstruction::unique_operands() const {
1742 InstructionVector unique;
1743 absl::flat_hash_set<const HloInstruction*> seen;
1744 for (HloInstruction* operand : operands()) {
1745 if (seen.insert(operand).second) {
1746 unique.push_back(operand);
1747 }
1748 }
1749 return unique;
1750 }
1751
// Adds a control edge from this instruction (predecessor) to `instruction`
// (successor).  Both must belong to the same computation; re-adding an
// existing edge is a no-op.
Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
  TF_RET_CHECK(instruction->parent() == parent());
  if (!absl::c_linear_search(control_successors_, instruction)) {
    control_successors_.push_back(instruction);
    // The successor/predecessor lists must stay in sync: if the edge was
    // absent on our side it must be absent on the other side too.
    TF_RET_CHECK(
        !absl::c_linear_search(instruction->control_predecessors_, this));
    instruction->control_predecessors_.push_back(this);
  }
  return Status::OK();
}
1762
// Removes the control edge this -> `instruction`, updating both sides'
// lists.  Returns an error if the edge is missing from either list.
Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
  TF_RET_CHECK(instruction->parent() == parent());
  TF_RETURN_IF_ERROR(EraseElementFromVector(&control_successors_, instruction));
  TF_RETURN_IF_ERROR(
      EraseElementFromVector(&instruction->control_predecessors_, this));
  return Status::OK();
}
1770
// Severs every control edge touching this instruction: removes this
// instruction from each neighbor's list, then clears our own lists.
Status HloInstruction::DropAllControlDeps() {
  for (auto* ctrl_succ : control_successors_) {
    TF_RETURN_IF_ERROR(
        EraseElementFromVector(&ctrl_succ->control_predecessors_, this));
  }
  for (auto* ctrl_pred : control_predecessors_) {
    TF_RETURN_IF_ERROR(
        EraseElementFromVector(&ctrl_pred->control_successors_, this));
  }
  control_successors_.clear();
  control_predecessors_.clear();
  return Status::OK();
}
1784
CopyAllControlDepsFrom(const HloInstruction * inst)1785 Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) {
1786 for (auto* ctrl_pred : inst->control_predecessors()) {
1787 TF_RETURN_IF_ERROR(ctrl_pred->AddControlDependencyTo(this));
1788 }
1789
1790 for (auto* ctrl_succ : inst->control_successors()) {
1791 TF_RETURN_IF_ERROR(this->AddControlDependencyTo(ctrl_succ));
1792 }
1793
1794 return Status::OK();
1795 }
1796
// Appends `operand` to the operand list and registers this instruction as a
// user of `operand` (AddUser deduplicates, so repeated operands yield one
// user entry).
void HloInstruction::AppendOperand(HloInstruction* operand) {
  operands_.push_back(operand);
  operand->AddUser(this);
}
1801
// Removes the operands at `ascending_indices` (which must be sorted in
// ascending order) by compacting the surviving operands to the left in one
// pass and shrinking the vector.  Note: the removed operands' user lists are
// NOT updated here.
void HloInstruction::RemoveOperandsAtAscendingIndices(
    absl::Span<const int> ascending_indices) {
  if (ascending_indices.empty()) {
    return;
  }
  int next_index = 0;
  int removed_count = 0;
  for (int to_remove : ascending_indices) {
    // Shift the operands between the previous removal point and `to_remove`
    // left by the number of slots removed so far.
    while (next_index < to_remove) {
      operands_[next_index - removed_count] = operands_[next_index];
      ++next_index;
    }
    CHECK_LT(to_remove, operands_.size());
    ++removed_count;
    ++next_index;  // Skip over (drop) the removed operand itself.
  }
  // Shift the tail that follows the last removed index.
  while (next_index < operands_.size()) {
    operands_[next_index - removed_count] = operands_[next_index];
    ++next_index;
  }
  CHECK_EQ(removed_count, ascending_indices.size());
  operands_.resize(operands_.size() - removed_count);
}
1825
// Registers `user` as a user of this instruction.  users_ preserves
// insertion order while user_map_ maps user -> index in users_ (enabling the
// O(1) swap-remove in RemoveUser).  Adding an existing user is a no-op.
void HloInstruction::AddUser(HloInstruction* user) {
  if (!ContainsKey(user_map_, user)) {
    user_map_.emplace(user, users_.size());
    users_.push_back(user);
  }
}
1832
// Returns the index of `user` within users_.  `user` must currently be a
// user of this instruction (CHECK-fails otherwise).
int64 HloInstruction::UserId(HloInstruction* user) {
  auto result = user_map_.find(user);
  CHECK(result != user_map_.end());
  return result->second;
}
1838
HasConstantOperand() const1839 bool HloInstruction::HasConstantOperand() const {
1840 for (const HloInstruction* operand : operands_) {
1841 if (operand->IsConstant()) {
1842 return true;
1843 }
1844 }
1845 return false;
1846 }
1847
// Opcode-specific part of instruction equality.  `other` is presumed to have
// already matched on opcode/shape/operands (this is the "slow path" invoked
// after those cheap checks -- confirm against the Identical() caller).
// `eq_computations` decides whether two called computations are equal.
bool HloInstruction::IdenticalSlowPath(
    const HloInstruction& other,
    const std::function<bool(const HloComputation*, const HloComputation*)>&
        eq_computations) const {
  // Perform opcode specific checks.
  switch (opcode()) {
    // The result of these instructions only depend upon their opcode and
    // operands.
    case HloOpcode::kAbs:
    case HloOpcode::kAtan2:
    case HloOpcode::kAdd:
    case HloOpcode::kBitcast:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kComplex:
    case HloOpcode::kConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCos:
    case HloOpcode::kDivide:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kAnd:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kPartitionId:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kRemainder:
    case HloOpcode::kReshape:
    case HloOpcode::kReplicaId:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kTuple:
    case HloOpcode::kTupleSelect:
      return true;

    // This opcode has complex or special behavior so just return false.
    case HloOpcode::kAfterAll:
    case HloOpcode::kAddDependency:
      return false;

    // Remaining instructions with special values.
    case HloOpcode::kCall:
      return eq_computations(to_apply(), other.to_apply());
    case HloOpcode::kConditional:
      // Equal iff every corresponding pair of branch computations is equal.
      for (int j = 0; j < branch_count(); ++j) {
        if (!eq_computations(branch_computation(j),
                             other.branch_computation(j))) {
          return false;
        }
      }
      return true;
    case HloOpcode::kWhile:
      return (eq_computations(while_body(), other.while_body()) &&
              eq_computations(while_condition(), other.while_condition()));

    // Ops migrated to subclasses should never come to this line.
    // TODO(b/80131774): Remove this switch when migration is complete.
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kFft:
    case HloOpcode::kCompare:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReverse:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReduce:
    case HloOpcode::kSort:
    case HloOpcode::kTranspose:
    case HloOpcode::kBroadcast:
    case HloOpcode::kMap:
    case HloOpcode::kSlice:
    case HloOpcode::kConstant:
    case HloOpcode::kIota:
    case HloOpcode::kTrace:
    case HloOpcode::kFusion:
    case HloOpcode::kRng:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kParameter:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kConvolution:
    case HloOpcode::kCustomCall:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kPad:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kGather:
    case HloOpcode::kScatter:
    case HloOpcode::kDot:
    case HloOpcode::kDomain:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
      LOG(FATAL) << "Base class impl called for opcode with subclass: "
                 << opcode();
  }
  return false;
}
1981
// Default operand hasher for HloInstruction::Hash(): hashes only the
// operand's shape, not the operand itself, keeping the recursion one level
// deep (see the non-termination note at Hash()).
static uint64 HashOperand(const HloInstruction* hlo) {
  return ShapeUtil::Hash(hlo->shape());
}
1985
Hash(const std::function<uint64 (const HloInstruction *)> & hash_operand) const1986 uint64 HloInstruction::Hash(
1987 const std::function<uint64(const HloInstruction*)>& hash_operand) const {
1988 using tensorflow::Hash64Combine;
1989
1990 uint64 hash_value = Hash64Combine(0, static_cast<uint64>(opcode()));
1991 hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(shape()));
1992
1993 if (!IsCrossModuleAllReduce()) {
1994 if (!operands().empty()) {
1995 for (size_t i = 0; i < operands().size(); ++i) {
1996 hash_value = Hash64Combine(hash_value, hash_operand(operand(i)));
1997 }
1998 }
1999 }
2000
2001 hash_value = Hash64Combine(hash_value, InnerHash());
2002 return hash_value;
2003 }
2004
// Convenience overload of Hash() using the shape-only operand hasher.
uint64 HloInstruction::Hash() const {
  // Use HashOperand as an argument to prevent non-termination.
  return Hash(HashOperand);
}
2009
// Base-class contribution to Hash(); 13 is an arbitrary constant.
// Presumably overridden by subclasses to mix in attribute data -- confirm in
// hlo_instructions.h.
uint64 HloInstruction::InnerHash() const { return 13; }
2011
// Unregisters `user` from this instruction's user set in O(1) by swapping
// the last user into the removed slot.  `user` must currently be a user.
void HloInstruction::RemoveUser(HloInstruction* user) {
  auto map_it = user_map_.find(user);
  CHECK(map_it != user_map_.end());

  const int64 index = map_it->second;
  CHECK_EQ(users_[index], user);

  // Move the last user into the position of the removed user.
  users_[index] = users_.back();
  user_map_[users_.back()] = index;

  // Remove the user from the map and drop the last slot from the vector,
  // whose element has been moved to the position of the removed user.
  user_map_.erase(map_it);
  users_.pop_back();
}
2028
// Replaces all uses of this instruction within `user` by `new_producer`,
// requiring the shapes to be compatible (fp precision ignored).  Delegates
// to ReplaceUseWithDifferentShape after the shape check.
Status HloInstruction::ReplaceUseWith(HloInstruction* user,
                                      HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << "this shape: " << ShapeUtil::HumanString(shape())
      << ", replacement shape: "
      << ShapeUtil::HumanString(new_producer->shape());
  return ReplaceUseWithDifferentShape(user, new_producer);
}
2038
ReplaceUseWithDifferentShape(HloInstruction * user,HloInstruction * new_producer)2039 Status HloInstruction::ReplaceUseWithDifferentShape(
2040 HloInstruction* user, HloInstruction* new_producer) {
2041 VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
2042 << " with " << new_producer->name();
2043
2044 RemoveUser(user);
2045
2046 TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0);
2047 std::replace(user->operands_.begin(), user->operands_.end(), this,
2048 new_producer);
2049 new_producer->AddUser(user);
2050 // Custom fusions may not be able to handle deduplicated operands.
2051 if (user->opcode() == HloOpcode::kFusion) {
2052 TF_RETURN_IF_ERROR(
2053 Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
2054 }
2055 return Status::OK();
2056 }
2057
// Replaces operand `operand_num` with `new_operand`, requiring compatible
// shapes (fp precision ignored).  Delegates to
// ReplaceOperandWithDifferentShape after the shape check.
Status HloInstruction::ReplaceOperandWith(int64 operand_num,
                                          HloInstruction* new_operand) {
  auto old_operand = operand(operand_num);
  TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
                                                        new_operand->shape()))
      << old_operand->shape() << " is not compatible with "
      << new_operand->shape();
  return ReplaceOperandWithDifferentShape(operand_num, new_operand);
}
2067
// Replaces operand `operand_num` with `new_operand` without any shape check,
// keeping the user lists of both operands consistent.
Status HloInstruction::ReplaceOperandWithDifferentShape(
    int64 operand_num, HloInstruction* new_operand) {
  TF_RET_CHECK(operand_num >= 0);
  TF_RET_CHECK(operand_num < operand_count());
  HloInstruction* old_operand = mutable_operand(operand_num);
  if (old_operand == new_operand) {
    return Status::OK();
  }

  operands_[operand_num] = new_operand;

  VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
          << new_operand->name() << ", was " << old_operand->name();

  // Only drop the user link if the old operand no longer appears in ANY
  // operand slot (it may be a duplicated operand).
  if (!absl::c_linear_search(operands_, old_operand)) {
    old_operand->RemoveUser(this);
  }
  new_operand->AddUser(this);
  return Status::OK();
}
2088
// Redirects all users of this instruction to `new_producer`, requiring
// compatible shapes (fp precision ignored).
Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
  TF_RET_CHECK(
      ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
      << shape() << " is not compatible with " << new_producer->shape();
  return ReplaceAllUsesWithDifferentShape(new_producer);
}
2095
// Redirects all users of this instruction to `new_producer` without a shape
// check.  Also transfers computation-root status if this instruction was the
// root of its parent computation.
Status HloInstruction::ReplaceAllUsesWithDifferentShape(
    HloInstruction* new_producer) {
  bool new_producer_is_user = false;
  for (HloInstruction* user : users()) {
    if (user == new_producer) {
      // It's possible that new_producer is a user of this instruction as might
      // be the case when replacing an instruction with a kCopy of itself. In
      // this case, don't do the replacement to avoid creating a cycle in the
      // graph. new_producer remains the only user of this instruction.
      new_producer_is_user = true;
    } else {
      std::replace(user->operands_.begin(), user->operands_.end(), this,
                   new_producer);
      new_producer->AddUser(user);
      // Custom fusions may not be able to handle deduplicated operands.
      if (user->opcode() == HloOpcode::kFusion) {
        TF_RETURN_IF_ERROR(
            Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
      }
    }
  }
  users_.clear();
  user_map_.clear();
  if (new_producer_is_user) {
    AddUser(new_producer);
  }
  if (parent_ && parent_->root_instruction() == this) {
    parent_->set_root_instruction(new_producer,
                                  /*accept_different_shape=*/true);
  }

  return Status::OK();
}
2128
// Returns the single computation applied by this instruction.  Only valid
// for the opcodes listed below, each of which calls exactly one computation.
HloComputation* HloInstruction::to_apply() const {
  switch (opcode_) {
    case HloOpcode::kCall:
    case HloOpcode::kMap:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kReduce:
    case HloOpcode::kAllReduce:
    case HloOpcode::kScatter:
    case HloOpcode::kSort:
      CHECK_EQ(called_computations_.size(), 1);
      return called_computations_[0];
    default:
      LOG(FATAL) << "Invalid opcode for to_apply(): "
                 << HloOpcodeString(opcode());
  }
}
2145
set_to_apply(HloComputation * computation)2146 void HloInstruction::set_to_apply(HloComputation* computation) {
2147 // Don't allow changing the computation for fused instructions so we don't
2148 // have to recompute called_instructions for the entire fusion instruction.
2149 CHECK(!IsFused());
2150 switch (opcode_) {
2151 case HloOpcode::kCall:
2152 case HloOpcode::kMap:
2153 case HloOpcode::kReduceWindow:
2154 case HloOpcode::kReduce:
2155 case HloOpcode::kAllReduce:
2156 case HloOpcode::kScatter:
2157 case HloOpcode::kSort:
2158 CHECK_EQ(called_computations_.size(), 1);
2159 called_computations_[0] = computation;
2160 break;
2161 default:
2162 LOG(FATAL) << "Invalid opcode for to_apply(): "
2163 << HloOpcodeString(opcode());
2164 }
2165 }
2166
// Returns the condition computation of a kWhile instruction.
HloComputation* HloInstruction::while_condition() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  return called_computations_[kConditionComputationIndex];
}
2171
// Returns the body computation of a kWhile instruction.
HloComputation* HloInstruction::while_body() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  return called_computations_[kBodyComputationIndex];
}
2176
// Replaces the condition computation of a kWhile instruction (disallowed on
// fused instructions).
void HloInstruction::set_while_condition(HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  called_computations_[kConditionComputationIndex] = computation;
}
2184
// Replaces the body computation of a kWhile instruction (disallowed on fused
// instructions).
void HloInstruction::set_while_body(HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  called_computations_[kBodyComputationIndex] = computation;
}
2192
// Returns the init value (sole operand) of a kWhile instruction.
HloInstruction* HloInstruction::while_init() const {
  CHECK_EQ(HloOpcode::kWhile, opcode_);
  return operands_[0];
}
2197
// Returns the true-branch computation of a predicated (PRED-selector)
// kConditional instruction.
HloComputation* HloInstruction::true_computation() const {
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  CHECK_EQ(PRED, operand(0)->shape().element_type());
  return called_computations_[kTrueComputationIndex];
}
2203
// Returns the false-branch computation of a predicated (PRED-selector)
// kConditional instruction.
HloComputation* HloInstruction::false_computation() const {
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  CHECK_EQ(PRED, operand(0)->shape().element_type());
  return called_computations_[kFalseComputationIndex];
}
2209
// Returns all branch computations of a kConditional instruction.
const std::vector<HloComputation*>& HloInstruction::branch_computations()
    const {
  CHECK(HloOpcode::kConditional == opcode_);
  return called_computations_;
}
2215
// Returns the number of branches of a kConditional instruction.
int HloInstruction::branch_count() const {
  CHECK(HloOpcode::kConditional == opcode_);
  return called_computations_.size();
}
2220
// Returns the b-th branch computation of a kConditional instruction; b must
// be in [0, branch_count()).
HloComputation* HloInstruction::branch_computation(int b) const {
  CHECK(HloOpcode::kConditional == opcode_);
  CHECK_GE(b, 0);
  CHECK_LT(b, called_computations_.size());
  return called_computations_[b];
}
2227
// Replaces the b-th branch computation of a kConditional instruction
// (disallowed on fused instructions).
void HloInstruction::set_branch_computation(int b,
                                            HloComputation* computation) {
  // Don't allow changing the computation for fused instructions so we don't
  // have to recompute called_instructions for the entire fusion instruction.
  CHECK(!IsFused());
  CHECK_EQ(HloOpcode::kConditional, opcode_);
  called_computations_[b] = computation;
}
2236
// Returns a human-readable type signature: "(operand shapes) -> result
// shape".
string HloInstruction::SignatureString() const {
  string operands =
      StrJoin(operands_, ", ", [](string* out, HloInstruction* operand) {
        StrAppend(out, ShapeUtil::HumanString(operand->shape()));
      });
  return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
}
2244
// Formats an instruction/computation name for printing. When `print_ids` is
// true the name is returned unchanged; otherwise everything from the first
// '.' onward is stripped (ids are appended after a '.' — so "add.7" prints
// as "add").
std::string PrintName(const std::string& name, bool print_ids) {
  if (print_ids) {
    return name;
  }
  // find('.') is equivalent to find_first_of(".") for a single-character set
  // but avoids the set scan (clang-tidy: performance-faster-string-find).
  // substr(0, npos) returns the whole name when there is no '.'.
  return name.substr(0, name.find('.'));
}
2253
namespace {

// DFS stack entry: <unique_id, instruction>. The id is stored alongside the
// pointer because instructions can be deleted while still on the stack (see
// PostOrderDFS below).
using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;

// Applies the options' percent-prefix and id-stripping rules to `name`.
string PrintNameInternal(const string& name, const HloPrintOptions& options) {
  return StrCat(options.print_percent() ? "%" : "",
                PrintName(name, options.print_ids()));
}

// Logs a directed cycle through `child` for diagnostics. Consumes entries
// from `dfs_stack` down to (but not including) `child`'s entry.
void PrintCycle(const HloInstruction* child, DFSStack* dfs_stack) {
  // This set contains HloInstructions from the top of `DFSStack` that might
  // belong to the cycle, i.e. if DFSStack :=[back,...,child,...,top], then
  // `subgraph` := {child,...,top}.
  absl::flat_hash_set<const HloInstruction*> subgraph;
  while (!dfs_stack->empty() && dfs_stack->back().second != child) {
    subgraph.insert(dfs_stack->back().second);
    dfs_stack->pop_back();
  }
  // Start dfs at `child` and find a cycle with all nodes in `subgraph`.
  absl::flat_hash_set<const HloInstruction*> visited;
  absl::InlinedVector<const HloInstruction*, 16> dfs;
  dfs.push_back(child);
  while (!dfs.empty()) {
    bool found_next_instr = false;
    for (const auto& user : dfs.back()->users()) {
      if (user == child) {
        // Reached `child` again: report the chain currently on `dfs`.
        // NOTE(review): `dfs` is the raw search stack, not necessarily a
        // simple path — the printed chain may include non-cycle nodes when a
        // node pushed several users at once; verify before relying on it.
        dfs.push_back(child);
        LOG(INFO) << "\n\nDirected cycle:\n  "
                  << absl::StrJoin(
                         dfs, "\n  ",
                         [](std::string* out, const HloInstruction* instr) {
                           out->append(instr->name());
                         });
        return;
      }
      if (!subgraph.contains(user) || visited.contains(user)) {
        continue;
      }
      visited.insert(user);
      dfs.push_back(user);
      found_next_instr = true;
    }
    if (!found_next_instr) {
      dfs.pop_back();
    }
  }
}

}  // namespace
2303
ToString(const HloPrintOptions & options) const2304 string HloInstruction::ToString(const HloPrintOptions& options) const {
2305 CanonicalNameMap new_map;
2306 return ToStringWithCanonicalNameMap(options, &new_map);
2307 }
2308
// Returns true if this instruction is elementwise. When `operand_idx` is
// provided, answers whether the op is elementwise with respect to that single
// operand (relevant for kDynamicUpdateSlice, which is elementwise only on
// operand 0).
bool HloInstruction::IsElementwiseImpl(
    const absl::optional<int64>& operand_idx) const {
  switch (opcode_) {
    // Unary elementwise operations.
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kCeil:
    case HloOpcode::kClz:
    case HloOpcode::kConvert:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kTanh:
      CHECK_EQ(1, operand_count());
      return true;

    // Binary elementwise operations, the same as in IsElementwiseBinary().
    case HloOpcode::kAdd:
    case HloOpcode::kAtan2:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kSubtract:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      CHECK_EQ(2, operand_count());
      return true;

    // Ternary elementwise operations.
    case HloOpcode::kSelect:
    case HloOpcode::kClamp:
      return true;

    // Elementwise only on the data operand (operand 0), not on the update or
    // start indices.
    case HloOpcode::kDynamicUpdateSlice:
      return operand_idx.has_value() && operand_idx.value() == 0;

    default:
      return false;
  }
}
2374
IsCrossModuleAllReduce() const2375 bool HloInstruction::IsCrossModuleAllReduce() const {
2376 return opcode() == HloOpcode::kAllReduce && channel_id();
2377 }
2378
IsCrossReplicaAllReduce() const2379 bool HloInstruction::IsCrossReplicaAllReduce() const {
2380 return opcode() == HloOpcode::kAllReduce && !channel_id();
2381 }
2382
// Core printing implementation shared by ToString() and nested-computation
// printing. `canonical_name_map` supplies stable replacement names when
// options.canonicalize_instruction_names() is set.
string HloInstruction::ToStringWithCanonicalNameMap(
    const HloPrintOptions& options,
    CanonicalNameMap* canonical_name_map) const {
  string result = "";

  // Logic to print the instruction name (e.g. "%foo = ").
  if (options.canonicalize_instruction_names()) {
    if (options.is_in_nested_computation()) {
      // If we are canonicalizing instruction names and this is a top-level
      // HloInstruction::ToString() call, don't print an instruction name.
      StrAppend(&result,
                PrintNameInternal(canonical_name_map->LookupOrInsert(name()),
                                  options),
                " = ");
    }
  } else {
    StrAppend(&result, PrintNameInternal(name(), options), " = ");
  }

  // Print shape.
  if (options.include_layout_in_shapes()) {
    StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()));
  } else {
    StrAppend(&result, ShapeUtil::HumanString(shape()));
  }

  // Print opcode, operand(s).
  StrAppend(&result, " ", HloOpcodeString(opcode()), "(",
            OperandsToStringWithCanonicalNameMap(options, canonical_name_map),
            ")");

  // Print additional attributes. If an instruction contains a subcomputation,
  // the subcomputation is also printed here.
  for (const string& extra : ExtraAttributesToString(options)) {
    StrAppend(&result, ", ", extra);
  }

  // Metadata is printed only when requested and any field is non-empty.
  if (options.print_metadata() &&
      (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
       !metadata_.source_file().empty())) {
    StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
  }
  // Backend config is opaque; CEscape keeps the output parseable.
  if (options.print_backend_config() && !backend_config_.empty()) {
    StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\"");
  }
  return result;
}
2430
OperandsToString(const HloPrintOptions & options) const2431 string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
2432 CanonicalNameMap new_map;
2433 return OperandsToStringWithCanonicalNameMap(options, &new_map);
2434 }
2435
// Prints the comma-separated operand list. In compact mode at most
// kMaxOperandsToShowIfCompact operands are shown, followed by "...(+N)" for
// the omitted remainder.
string HloInstruction::OperandsToStringWithCanonicalNameMap(
    const HloPrintOptions& options,
    CanonicalNameMap* canonical_name_map) const {
  string operands;
  absl::Span<HloInstruction* const> slice(operands_);
  const int64 kMaxOperandsToShowIfCompact = 4;
  if (options.compact_operands() &&
      slice.size() > kMaxOperandsToShowIfCompact) {
    slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
  }
  operands = StrJoin(slice, ", ", [&](string* out, HloInstruction* operand) {
    // If the operand has already been deleted, put `null` in the string
    // output.
    if (operand == nullptr) {
      StrAppend(out, "null ");
      return;
    }
    std::vector<string> str;
    if (options.print_operand_shape()) {
      if (options.include_layout_in_shapes()) {
        str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
      } else {
        str.push_back(ShapeUtil::HumanString(operand->shape()));
      }
    }

    // In a top-level HloInstruction::ToString() call, the operand name is not
    // part of the canonical string.
    if (options.canonicalize_instruction_names() &&
        options.is_in_nested_computation()) {
      str.push_back(PrintNameInternal(
          canonical_name_map->LookupOrInsert(operand->name()), options));
    } else if (options.print_operand_names()) {
      str.push_back(PrintNameInternal(operand->name(), options));
    }
    StrAppend(out, StrJoin(str, " "));
  });
  // Note how many operands were elided in compact mode.
  const int64 remaining = operands_.size() - slice.size();
  if (slice.size() != operands_.size()) {
    StrAppend(&operands, ", ...(+", remaining, ")");
  }
  return operands;
}
2478
// Returns the list of extra printed attributes (subcomputations, sharding,
// frontend attributes, partitioning, control dependencies), each formatted as
// "key=value". ToStringWithCanonicalNameMap joins these with ", ".
std::vector<string> HloInstruction::ExtraAttributesToString(
    const HloPrintOptions& options) const {
  std::vector<string> extra = ExtraAttributesToStringImpl(options);

  if (options.print_subcomputation_mode() ==
      HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
    // Name-only mode: print each subcomputation by name.
    if (opcode() == HloOpcode::kWhile) {
      extra.push_back(StrCat(
          "condition=", PrintNameInternal(while_condition()->name(), options)));
      extra.push_back(
          StrCat("body=", PrintNameInternal(while_body()->name(), options)));
    } else if (opcode() == HloOpcode::kSelectAndScatter) {
      extra.push_back(
          StrCat("select=", PrintNameInternal(select()->name(), options)));
      extra.push_back(
          StrCat("scatter=", PrintNameInternal(scatter()->name(), options)));
    } else if (opcode() == HloOpcode::kConditional) {
      // Predicated conditionals print true/false computations; indexed
      // conditionals print the whole branch list.
      if (operand(0)->shape().element_type() == PRED) {
        extra.push_back(
            StrCat("true_computation=",
                   PrintNameInternal(true_computation()->name(), options)));
        extra.push_back(
            StrCat("false_computation=",
                   PrintNameInternal(false_computation()->name(), options)));
      } else {
        extra.push_back(StrCat(
            "branch_computations={",
            StrJoin(branch_computations(), ", ",
                    [&](string* out, const HloComputation* computation) {
                      StrAppend(
                          out, PrintNameInternal(computation->name(), options));
                    }),
            "}"));
      }
    } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
               opcode() == HloOpcode::kReduceWindow ||
               opcode() == HloOpcode::kReduce ||
               opcode() == HloOpcode::kAllReduce ||
               opcode() == HloOpcode::kScatter ||
               opcode() == HloOpcode::kSort) {
      extra.push_back(
          StrCat("to_apply=", PrintNameInternal(to_apply()->name(), options)));
    } else if (!called_computations().empty()) {
      // Any other opcode with subcomputations prints them as "calls=".
      extra.push_back(StrCat(
          "calls=",
          StrJoin(called_computations(), ", ",
                  [&](string* out, const HloComputation* computation) {
                    StrAppend(out,
                              PrintNameInternal(computation->name(), options));
                  })));
    }
  } else if (options.print_subcomputation_mode() ==
             HloPrintOptions::PrintSubcomputationMode::kFullBodies) {
    // Full-body mode: print the entire subcomputation text, marked as nested
    // so name canonicalization applies inside.
    HloPrintOptions new_options = options;
    new_options.set_is_in_nested_computation(true);
    switch (opcode()) {
      case HloOpcode::kWhile:
        extra.push_back(
            StrCat("condition=\n", while_condition()->ToString(new_options)));
        extra.push_back(StrCat("body=\n", while_body()->ToString(new_options)));
        break;
      case HloOpcode::kSelectAndScatter:
        extra.push_back(StrCat("select=\n", select()->ToString(new_options)));
        extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options)));
        break;
      case HloOpcode::kConditional:
        if (operand(0)->shape().element_type() == PRED) {
          extra.push_back(StrCat("true_computation=\n",
                                 true_computation()->ToString(new_options)));
          extra.push_back(StrCat("false_computation=\n",
                                 false_computation()->ToString(new_options)));
        } else {
          extra.push_back(StrCat(
              "branch_computations={\n",
              StrJoin(branch_computations(), ",\n",
                      [&](string* out, const HloComputation* computation) {
                        StrAppend(out, computation->ToString(new_options));
                      }),
              "\n}"));
        }
        break;
      case HloOpcode::kCall:
      case HloOpcode::kMap:
      case HloOpcode::kReduceWindow:
      case HloOpcode::kReduce:
      case HloOpcode::kAllReduce:
      case HloOpcode::kScatter:
      case HloOpcode::kSort:
        extra.push_back(
            StrCat("to_apply=\n", to_apply()->ToString(new_options)));
        break;
      default:
        if (!called_computations().empty()) {
          extra.push_back(StrCat(
              "calls=\n",
              StrJoin(called_computations(), ", ",
                      [&](string* out, const HloComputation* computation) {
                        StrAppend(out, computation->ToString(new_options));
                      })));
        }
        break;
    }
  }

  if (has_sharding()) {
    extra.push_back(StrCat("sharding=", sharding().ToString()));
  }
  if (!frontend_attributes_.map().empty()) {
    extra.push_back(StrCat("frontend_attributes=",
                           FrontendAttributesToString(frontend_attributes_)));
  }
  if (!outer_dimension_partitions_.empty()) {
    extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
                                    StrJoin(outer_dimension_partitions_, ",")));
  }

  if (options.print_control_dependencies() && !control_predecessors_.empty()) {
    extra.push_back(StrCat("control-predecessors={",
                           StrJoin(control_predecessors_, ", ",
                                   [&](string* out, HloInstruction* pre) {
                                     StrAppend(out, PrintNameInternal(
                                                        pre->name(), options));
                                   }),
                           "}"));
  }

  return extra;
}
2607
ToShortString() const2608 string HloInstruction::ToShortString() const {
2609 return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
2610 StrJoin(operands_, ", ",
2611 [](string* out, HloInstruction* operand) {
2612 StrAppend(out, "%", operand->name());
2613 }),
2614 ")");
2615 }
2616
// Serializes this instruction into an HloInstructionProto. The instruction
// must have been assigned a valid unique id (i.e. it must live in a module).
HloInstructionProto HloInstruction::ToProto() const {
  HloInstructionProto proto;
  CHECK(unique_id_ != -1)
      << "This instruction does not have a valid id. Please make sure the "
         "instruction is inside a module before dumping it.";
  proto.set_id(unique_id_);
  proto.set_name(name_);
  proto.set_opcode(HloOpcodeString(opcode_));
  *proto.mutable_shape() = shape_.ToProto();
  // Operands and control predecessors are referenced by unique id.
  for (const HloInstruction* operand : operands_) {
    proto.add_operand_ids(operand->unique_id());
  }
  for (const HloInstruction* control : control_predecessors_) {
    proto.add_control_predecessor_ids(control->unique_id());
  }

  *proto.mutable_metadata() = metadata_;
  proto.set_backend_config(backend_config_);
  // Called-computation ids are skipped for fusions — presumably the fusion
  // subclass serializes its fused computation itself; confirm in
  // HloFusionInstruction.
  if (opcode() != HloOpcode::kFusion) {
    for (const HloComputation* computation : called_computations_) {
      proto.add_called_computation_ids(computation->unique_id());
    }
  }

  if (has_sharding()) {
    *proto.mutable_sharding() = sharding().ToProto();
  }
  if (!outer_dimension_partitions_.empty()) {
    for (const auto& idx : outer_dimension_partitions_) {
      proto.mutable_outer_dimension_partitions()->Add(idx);
    }
  }

  *proto.mutable_frontend_attributes() = frontend_attributes_;

  return proto;
}
2654
ToCategory() const2655 string HloInstruction::ToCategory() const {
2656 if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy ||
2657 opcode() == HloOpcode::kReshape) {
2658 return "data formatting";
2659 }
2660
2661 if (IsElementwise()) {
2662 return "non-fusion elementwise";
2663 }
2664
2665 return HloOpcodeString(opcode());
2666 }
2667
// Returns the trace instruction attached to this instruction, or nullptr.
HloInstruction* HloInstruction::tracing() const { return trace_instruction_; }
2669
// Attaches `trace_instruction` to this instruction (see tracing()).
void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
  trace_instruction_ = trace_instruction;
}
2673
// An instruction is "fused" iff its parent computation is a fusion
// computation (i.e. it lives inside some fusion instruction).
bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
2675
IsInputFusion() const2676 bool HloInstruction::IsInputFusion() const {
2677 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kInput;
2678 }
2679
IsLoopFusion() const2680 bool HloInstruction::IsLoopFusion() const {
2681 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kLoop;
2682 }
2683
IsOutputFusion() const2684 bool HloInstruction::IsOutputFusion() const {
2685 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kOutput;
2686 }
2687
IsCustomFusion() const2688 bool HloInstruction::IsCustomFusion() const {
2689 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kCustom;
2690 }
2691
IsFusible() const2692 bool HloInstruction::IsFusible() const {
2693 // Instructions which are traced should not be fused.
2694 if (tracing()) {
2695 return false;
2696 }
2697 // Some kinds of instructions don't make sense to fuse.
2698 switch (opcode_) {
2699 case HloOpcode::kDomain:
2700 case HloOpcode::kParameter:
2701 case HloOpcode::kWhile:
2702 case HloOpcode::kConditional:
2703 case HloOpcode::kCall:
2704 return false;
2705 // Fusions are always fusible.
2706 case HloOpcode::kFusion:
2707 // Side effecting reduce and reduce window would be invalid HLO.
2708 case HloOpcode::kMap:
2709 case HloOpcode::kReduce:
2710 case HloOpcode::kReduceWindow:
2711 return true;
2712 // Side effecting instructions cannot be fused.
2713 default:
2714 return !HasSideEffect();
2715 }
2716 }
2717
// Constructs a bare instruction with the opcode's name as the default name.
// unique_id_ stays -1 ("unassigned") until the instruction is added to a
// module (ToProto CHECKs this).
HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
    : unique_id_(-1),
      opcode_(opcode),
      shape_(shape),
      name_(HloOpcodeString(opcode)) {
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
}
2725
// Dispatches to the DfsHloVisitor handler corresponding to this instruction's
// opcode. Every opcode must be listed; a missing case falls through to the
// InternalError at the bottom.
template <typename HloInstructionPtr>
Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
  switch (opcode_) {
    case HloOpcode::kAbs:
      return visitor->HandleAbs(this);
    case HloOpcode::kAtan2:
      return visitor->HandleAtan2(this);
    case HloOpcode::kRoundNearestAfz:
      return visitor->HandleRound(this);
    case HloOpcode::kBatchNormTraining:
      return visitor->HandleBatchNormTraining(this);
    case HloOpcode::kBatchNormInference:
      return visitor->HandleBatchNormInference(this);
    case HloOpcode::kBatchNormGrad:
      return visitor->HandleBatchNormGrad(this);
    case HloOpcode::kSign:
      return visitor->HandleSign(this);
    case HloOpcode::kConstant:
      return visitor->HandleConstant(this);
    case HloOpcode::kGetTupleElement:
      return visitor->HandleGetTupleElement(this);
    case HloOpcode::kParameter:
      return visitor->HandleParameter(this);
    case HloOpcode::kCompare:
      return visitor->HandleCompare(this);
    case HloOpcode::kComplex:
      return visitor->HandleComplex(this);
    case HloOpcode::kAdd:
      return visitor->HandleAdd(this);
    case HloOpcode::kDivide:
      return visitor->HandleDivide(this);
    case HloOpcode::kSubtract:
      return visitor->HandleSubtract(this);
    case HloOpcode::kMaximum:
      return visitor->HandleMaximum(this);
    case HloOpcode::kMinimum:
      return visitor->HandleMinimum(this);
    case HloOpcode::kAnd:
      return visitor->HandleAnd(this);
    case HloOpcode::kOr:
      return visitor->HandleOr(this);
    case HloOpcode::kXor:
      return visitor->HandleXor(this);
    case HloOpcode::kShiftLeft:
      return visitor->HandleShiftLeft(this);
    case HloOpcode::kShiftRightArithmetic:
      return visitor->HandleShiftRightArithmetic(this);
    case HloOpcode::kShiftRightLogical:
      return visitor->HandleShiftRightLogical(this);
    case HloOpcode::kConcatenate:
      return visitor->HandleConcatenate(this);
    case HloOpcode::kConvert:
      return visitor->HandleConvert(this);
    case HloOpcode::kBitcastConvert:
      return visitor->HandleBitcastConvert(this);
    case HloOpcode::kCopy:
      return visitor->HandleCopy(this);
    case HloOpcode::kMultiply:
      return visitor->HandleMultiply(this);
    case HloOpcode::kDot:
      return visitor->HandleDot(this);
    case HloOpcode::kPower:
      return visitor->HandlePower(this);
    case HloOpcode::kRemainder:
      return visitor->HandleRemainder(this);
    case HloOpcode::kSelect:
      return visitor->HandleSelect(this);
    case HloOpcode::kTupleSelect:
      return visitor->HandleTupleSelect(this);
    case HloOpcode::kConvolution:
      return visitor->HandleConvolution(this);
    case HloOpcode::kFft:
      return visitor->HandleFft(this);
    case HloOpcode::kAllReduce:
      return visitor->HandleAllReduce(this);
    case HloOpcode::kAllToAll:
      return visitor->HandleAllToAll(this);
    case HloOpcode::kCollectivePermute:
      return visitor->HandleCollectivePermute(this);
    case HloOpcode::kReplicaId:
      return visitor->HandleReplicaId(this);
    case HloOpcode::kPartitionId:
      return visitor->HandlePartitionId(this);
    case HloOpcode::kTuple:
      return visitor->HandleTuple(this);
    case HloOpcode::kMap:
      return visitor->HandleMap(this);
    case HloOpcode::kClamp:
      return visitor->HandleClamp(this);
    case HloOpcode::kReduce:
      return visitor->HandleReduce(this);
    case HloOpcode::kReduceWindow:
      return visitor->HandleReduceWindow(this);
    case HloOpcode::kSelectAndScatter:
      return visitor->HandleSelectAndScatter(this);
    case HloOpcode::kNegate:
      return visitor->HandleNegate(this);
    case HloOpcode::kExp:
      return visitor->HandleExp(this);
    case HloOpcode::kExpm1:
      return visitor->HandleExpm1(this);
    case HloOpcode::kFloor:
      return visitor->HandleFloor(this);
    case HloOpcode::kCeil:
      return visitor->HandleCeil(this);
    case HloOpcode::kClz:
      return visitor->HandleClz(this);
    case HloOpcode::kLog:
      return visitor->HandleLog(this);
    case HloOpcode::kLog1p:
      return visitor->HandleLog1p(this);
    case HloOpcode::kTanh:
      return visitor->HandleTanh(this);
    case HloOpcode::kCos:
      return visitor->HandleCos(this);
    case HloOpcode::kSin:
      return visitor->HandleSin(this);
    case HloOpcode::kSqrt:
      return visitor->HandleSqrt(this);
    case HloOpcode::kRsqrt:
      return visitor->HandleRsqrt(this);
    case HloOpcode::kReal:
      return visitor->HandleReal(this);
    case HloOpcode::kImag:
      return visitor->HandleImag(this);
    case HloOpcode::kIsFinite:
      return visitor->HandleIsFinite(this);
    case HloOpcode::kNot:
      return visitor->HandleNot(this);
    case HloOpcode::kPopulationCount:
      return visitor->HandlePopulationCount(this);
    case HloOpcode::kBitcast:
      return visitor->HandleBitcast(this);
    case HloOpcode::kBroadcast:
      return visitor->HandleBroadcast(this);
    case HloOpcode::kPad:
      return visitor->HandlePad(this);
    case HloOpcode::kReshape:
      return visitor->HandleReshape(this);
    case HloOpcode::kTranspose:
      return visitor->HandleTranspose(this);
    case HloOpcode::kReverse:
      return visitor->HandleReverse(this);
    case HloOpcode::kReducePrecision:
      return visitor->HandleReducePrecision(this);
    case HloOpcode::kSlice:
      return visitor->HandleSlice(this);
    case HloOpcode::kDynamicSlice:
      return visitor->HandleDynamicSlice(this);
    case HloOpcode::kDynamicUpdateSlice:
      return visitor->HandleDynamicUpdateSlice(this);
    case HloOpcode::kSort:
      return visitor->HandleSort(this);
    case HloOpcode::kInfeed:
      return visitor->HandleInfeed(this);
    case HloOpcode::kOutfeed:
      return visitor->HandleOutfeed(this);
    case HloOpcode::kRng:
      return visitor->HandleRng(this);
    case HloOpcode::kRngGetAndUpdateState:
      return visitor->HandleRngGetAndUpdateState(this);
    case HloOpcode::kWhile:
      return visitor->HandleWhile(this);
    case HloOpcode::kFusion:
      return visitor->HandleFusion(this);
    case HloOpcode::kCall:
      return visitor->HandleCall(this);
    case HloOpcode::kConditional:
      return visitor->HandleConditional(this);
    case HloOpcode::kCustomCall:
      return visitor->HandleCustomCall(this);
    case HloOpcode::kCopyStart:
      return visitor->HandleCopyStart(this);
    case HloOpcode::kCopyDone:
      return visitor->HandleCopyDone(this);
    case HloOpcode::kRecv:
      return visitor->HandleRecv(this);
    case HloOpcode::kRecvDone:
      return visitor->HandleRecvDone(this);
    case HloOpcode::kSend:
      return visitor->HandleSend(this);
    case HloOpcode::kSendDone:
      return visitor->HandleSendDone(this);
    case HloOpcode::kGather:
      return visitor->HandleGather(this);
    case HloOpcode::kScatter:
      return visitor->HandleScatter(this);
    case HloOpcode::kDomain:
      return visitor->HandleDomain(this);
    case HloOpcode::kAfterAll:
      return visitor->HandleAfterAll(this);
    case HloOpcode::kAddDependency:
      return visitor->HandleAddDependency(this);
    case HloOpcode::kIota:
      return visitor->HandleIota(this);
    case HloOpcode::kGetDimensionSize:
      return visitor->HandleGetDimensionSize(this);
    case HloOpcode::kSetDimensionSize:
      return visitor->HandleSetDimensionSize(this);
    case HloOpcode::kTriangularSolve:
      return visitor->HandleTriangularSolve(this);
    case HloOpcode::kCholesky:
      return visitor->HandleCholesky(this);

    // These opcodes are not handled here.
    case HloOpcode::kTrace:
      return Status::OK();
  }
  return InternalError(
      "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "
      "please file a bug for XLA.",
      HloOpcodeString(opcode_));
}
2939
// Explicit instantiations for the two visitor flavors (mutable and const).
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
2943
// Push "child" onto the dfs_stack if not already visited. Returns false if a
// cycle was detected, and true otherwise.
template <typename Visitor>
inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
                         HloInstruction* child) {
  CHECK(child != nullptr);
  const int id = child->unique_id();
  CHECK_GE(id, 0) << "instruction may not have a parent computation";
  switch (visitor->GetVisitState(id)) {
    case Visitor::kVisiting:
      // `child` is an ancestor still being expanded: this edge closes a
      // cycle.
      return false;

    case Visitor::kVisited:
      // Nothing to do
      return true;

    case Visitor::kNotVisited:
      dfs_stack->push_back(std::make_pair(id, child));
      return true;
  }
}
2965
// Comparator over DFS stack entries (<unique_id, instruction> pairs); used to
// impose a deterministic operand visit order in PostOrderDFS.
using InternalCompareFunction =
    std::function<bool(std::pair<int, const HloInstruction*>,
                       std::pair<int, const HloInstruction*>)>;

// Iterative post-order DFS from `root` over operands (and, unless
// `ignore_control_predecessors` is set, control predecessors). Each reachable
// instruction is Preprocess'ed, Visit'ed and Postprocess'ed exactly once;
// returns FailedPrecondition if a cycle is detected.
template <typename Visitor>
static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
                           const InternalCompareFunction* operand_order,
                           bool ignore_control_predecessors) {
  // Calculating the instruction count within a module can be expensive on large
  // models so only do it if the visit state is empty. This will help when the
  // same visitor is reused across many computations of a single module.
  if (visitor->VisitStateCapacity() == 0) {
    visitor->ReserveVisitStates(root->GetModule()->instruction_count());
  }

  // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
  //
  // We need to keep track of both the id and the instruction because
  // instructions can get deleted while they are on the stack, so we
  // can't always use the (potentially dead) instruction object to grab
  // its id.
  DFSStack dfs_stack;
  dfs_stack.emplace_back(root->unique_id(), root);

  do {
    DCHECK(!dfs_stack.empty());

    int current_id = dfs_stack.back().first;
    HloInstruction* current_node = dfs_stack.back().second;
    CHECK_GE(current_id, 0) << current_id << ": " << current_node
                            << ": instruction may not have parent computation";
    typename Visitor::VisitState visit_state =
        visitor->GetVisitState(current_id);
    if (visit_state == Visitor::kVisited) {
      dfs_stack.pop_back();
      VLOG(3) << "Not visiting HLO (id = " << current_id
              << ") as it was already visited.";
      continue;
    }

    if (visit_state == Visitor::kVisiting) {
      // Second encounter: all children have been processed, so visit the node
      // itself now (this yields post-order).
      dfs_stack.pop_back();

      TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
      VLOG(2) << "Visiting HLO %" << current_node->name();
      TF_RETURN_IF_ERROR(current_node->Visit(visitor));
      visitor->SetVisitState(current_id, Visitor::kVisited);
      TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
      continue;
    }

    // First encounter: mark in-progress and push all children; the node stays
    // on the stack beneath them.
    visitor->SetVisitState(current_id, Visitor::kVisiting);

    const size_t old_dfs_stack_size = dfs_stack.size();
    for (HloInstruction* child : current_node->operands()) {
      if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
        PrintCycle(child, &dfs_stack);
        return FailedPrecondition(
            "A cycle is detected while visiting instruction %s",
            current_node->ToString());
      }
    }

    if (!ignore_control_predecessors) {
      for (HloInstruction* child : current_node->control_predecessors()) {
        if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
          PrintCycle(child, &dfs_stack);
          return FailedPrecondition(
              "A cycle is detected while visiting instruction %s",
              current_node->ToString());
        }
      }
    }

    if (operand_order != nullptr) {
      std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
                *operand_order);
    }

    // This makes the traversal order the same as what you'd expect
    // out of a recursive algorithm.
    std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
  } while (!dfs_stack.empty());

  return Status::OK();
}
3051
3052 template <typename HloInstructionPtr>
Accept(DfsHloVisitorBase<HloInstructionPtr> * visitor,bool call_finish_visit,bool ignore_control_predecessors)3053 Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
3054 bool call_finish_visit,
3055 bool ignore_control_predecessors) {
3056 VLOG(3) << "HloInstruction::Accept(%" << name() << ")";
3057 TF_RETURN_IF_ERROR(
3058 PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
3059 if (call_finish_visit) {
3060 TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
3061 }
3062 return Status::OK();
3063 }
3064
// Explicit instantiations for the two visitor flavors (mutable and const).
template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
3068
// Like Accept(), but operands are visited in the order induced by
// `operand_order`. The client comparator is adapted into the
// (id, instruction*) pair form that PostOrderDFS uses for its stack entries.
// Control predecessors are always visited here.
Status HloInstruction::AcceptWithOperandOrder(
    DfsHloVisitor* visitor, const CompareFunction& operand_order,
    bool call_finish_visit) {
  VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";
  InternalCompareFunction func = [&operand_order](
                                     std::pair<int, const HloInstruction*> a,
                                     std::pair<int, const HloInstruction*> b) {
    // Call the client's comparison function on the actual HloInstruction*
    // objects (ignoring the internal ids we also have in our stack entries)
    return operand_order(a.second, b.second);
  };
  TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &func,
                                  /*ignore_control_predecessors=*/false));
  if (call_finish_visit) {
    VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
    TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
    VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
  }
  VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
  return Status::OK();
}
3090
// Returns the result shape of this instruction.
const Shape& HloInstruction::shape() const { return shape_; }
3092
OperandIndices(const HloInstruction * operand) const3093 absl::InlinedVector<int64, 4> HloInstruction::OperandIndices(
3094 const HloInstruction* operand) const {
3095 absl::InlinedVector<int64, 4> result;
3096 for (int64 i = 0; i < operand_count(); ++i) {
3097 if (this->operand(i) == operand) {
3098 result.push_back(i);
3099 }
3100 }
3101 return result;
3102 }
3103
// True iff this instruction is elementwise and takes exactly two operands.
bool HloInstruction::IsElementwiseBinary() const {
  return IsElementwise() && operand_count() == 2;
}

// True iff this instruction is elementwise with respect to all of its
// operands (no particular operand index is singled out).
bool HloInstruction::IsElementwise() const {
  return IsElementwiseImpl(absl::nullopt);
}

// True iff this instruction is elementwise with respect to the single
// operand at `operand_idx`.
bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
  return IsElementwiseImpl(operand_idx);
}
3115
// A helper class for memoized, recursive computation of HloOpcode::kFusion
// in HloInstruction::OperandElementUse below.
class HloInstruction::FusionReusesParamElements {
 public:
  using UseKind = HloInstruction::UseKind;

  // We could rather iterate backwards through fused_instructions_ here, as it
  // is in reverse postorder, and compute whether each fused instruction reuses
  // the value of this parameter, which would save stack space but not allow us
  // to finish early if we find a reuse.
  //
  // Returns how the fusion-internal expression rooted at `hlo` uses the
  // elements of fusion parameter number `i`.
  static UseKind Compute(int64 i, const HloInstruction& hlo) {
    absl::flat_hash_map<const HloInstruction*, UseKind> memoization_cache;
    return ComputeInternal(i, hlo, &memoization_cache);
  }

 private:
  // Recursive worker for Compute(). `cache` memoizes per-instruction results
  // so shared subexpressions are only walked once.
  static UseKind ComputeInternal(
      int64 i, const HloInstruction& hlo,
      absl::flat_hash_map<const HloInstruction*, UseKind>* cache) {
    // Parameter number `i` itself is, by definition, a plain use of the
    // tracked parameter.
    if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) {
      if (hlo_param->parameter_number() == i) {
        return UseKind::kUse;
      }
    }

    // emplace() inserts a kNoUse placeholder; p.second tells us whether this
    // instruction had already been computed, in which case the memoized
    // value is returned unchanged.
    auto p = cache->emplace(&hlo, UseKind::kNoUse);
    auto value_it = p.first;
    const bool key_is_new = p.second;

    if (key_is_new) {
      for (int64 j = 0; j < hlo.operands_.size(); ++j) {
        UseKind old_val = value_it->second;

        // The next operation invalidates iterators.
        UseKind new_val =
            Fold(old_val,
                 FoldUseMandatory(hlo.OperandElementUse(j),
                                  ComputeInternal(i, *hlo.operand(j), cache)));

        // Re-acquire the iterator. We could work harder to do this only if
        // absolutely necessary, but this code is not hot enough to warrant
        // that.
        value_it = cache->find(&hlo);
        value_it->second = new_val;
      }
    }
    return value_it->second;
  }

  // Combines two UseKinds.
  //
  // This is the min operation on the lattice
  //
  //   kReuse < kUse < kNoUse.
  //
  // Two kUses uses which have different permutations count as kReuse.
  static UseKind Fold(UseKind a, UseKind b) {
    // Without loss of generality, let `b` be the operation with the larger use
    // kind.
    if (b.kind < a.kind) {
      std::swap(a, b);
    }
    // If the kinds are different, return the smaller one, namely `a`.
    if (a.kind != b.kind) {
      return a;
    }
    // If the kinds are both kUse, check that they're the same permutation.
    if (a.kind == UseKind::kUse && b.kind == UseKind::kUse &&
        a.permutation_instr != b.permutation_instr) {
      return UseKind::kReuse;
    }
    return a;  // They're the same.
  }

  // Combines two UseKinds differently than Fold().
  //
  // This is the min operation on the lattice
  //
  //   kNoUse < kReuse < kUse.
  //
  // If `a` and `b` are both kUse and one has a non-null permutation
  // instruction, returns kUse with that permutation. OTOH if both have
  // different, non-null permutation instructions, returns kReuse.
  //
  // You can think of this sort of as a conjunction, whereas Fold is sort of a
  // disjunction. FoldUseMandatory() says "no use" if either input isn't used,
  // whereas Fold() would say "use".
  static UseKind FoldUseMandatory(UseKind a, UseKind b) {
    if (a.kind == UseKind::kNoUse || b.kind == UseKind::kNoUse) {
      return UseKind::kNoUse;
    }
    if (a.kind == UseKind::kReuse || b.kind == UseKind::kReuse) {
      return UseKind::kReuse;
    }
    if (a.permutation_instr == b.permutation_instr) {
      return a;  // They're the same.
    }
    if (b.permutation_instr == nullptr) {
      return a;
    }
    if (a.permutation_instr == nullptr) {
      return b;
    }
    return UseKind::kReuse;
  }
};
3222
// Returns how this instruction uses the elements of operand `operand_num`:
// a plain kUse, a permuting kUse (elements read but rearranged — see
// UseKind::Permuting), or kReuse (elements potentially read more than once).
HloInstruction::UseKind HloInstruction::OperandElementUse(
    int64 operand_num) const {
  switch (opcode_) {
    case HloOpcode::kBitcast:
      // A bitcast that only adds or removes degenerate (i.e. size 1) dimensions
      // doesn't permute its elements, so it counts as a plain, non-permuting
      // use.
      return ShapeUtil::DropDegenerateDimensions(shape()) ==
                     ShapeUtil::DropDegenerateDimensions(operand(0)->shape())
                 ? UseKind::kUse
                 : UseKind::Permuting(this);
    case HloOpcode::kConcatenate:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kSlice:
    case HloOpcode::kTranspose:
      return UseKind::Permuting(this);
    case HloOpcode::kPad:
      // Pad reuses the padding value but not the padded array elements.
      return operand_num > 0 ? UseKind::kReuse : UseKind::Permuting(this);
    case HloOpcode::kReduce:
      // Reduce reuses the init values but not the operand array elements.
      return operand_num >= Cast<HloReduceInstruction>(this)->input_count()
                 ? UseKind::kReuse
                 : UseKind::Permuting(this);
    case HloOpcode::kFusion:
      // Uses the memoizing, recursive computation defined above.
      return FusionReusesParamElements::Compute(operand_num,
                                                *fused_expression_root());
    case HloOpcode::kDot:
      // Matrix-vector dots do not reuse the matrix operand.
      if (shape().dimensions_size() <= 1) {
        if ((operand_num == 0 && operand(1)->shape().rank() <= 1) ||
            (operand_num == 1 && operand(0)->shape().rank() <= 1)) {
          return UseKind::kUse;
        }
      }
      return UseKind::kReuse;
    case HloOpcode::kDynamicUpdateSlice:
      // Dynamic-update-slice reuses only start_indices.
      if (operand_num == 0 || operand_num == 1) {
        return UseKind::kUse;
      }
      return UseKind::kReuse;
    case HloOpcode::kGather:
      // Gather reads its indices in a linear fashion, and it permutes the
      // vector it's gathering from.
      return operand_num == 0 ? UseKind::kUse : UseKind::Permuting(this);
    default:
      // Elementwise ops touch each operand element exactly once; anything
      // else is conservatively treated as reusing its operands.
      return IsElementwise() ? UseKind::kUse : UseKind::kReuse;
  }
}
3275
3276 std::tuple<bool, std::vector<int64>, std::vector<int64>>
ReshapeMerelyInsertsOrDeletes1SizedDimensions() const3277 HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
3278 if (HloOpcode::kReshape != opcode_) {
3279 return std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
3280 }
3281 return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
3282 shape_);
3283 }
3284
ToString(HloInstruction::FusionKind kind)3285 string ToString(HloInstruction::FusionKind kind) {
3286 switch (kind) {
3287 case HloInstruction::FusionKind::kLoop:
3288 return "kLoop";
3289 case HloInstruction::FusionKind::kInput:
3290 return "kInput";
3291 case HloInstruction::FusionKind::kOutput:
3292 return "kOutput";
3293 case HloInstruction::FusionKind::kCustom:
3294 return "kCustom";
3295 }
3296 }
3297
StringToFusionKind(const string & kind_name)3298 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
3299 const string& kind_name) {
3300 if (kind_name == "kLoop") {
3301 return HloInstruction::FusionKind::kLoop;
3302 }
3303 if (kind_name == "kInput") {
3304 return HloInstruction::FusionKind::kInput;
3305 }
3306 if (kind_name == "kOutput") {
3307 return HloInstruction::FusionKind::kOutput;
3308 }
3309 if (kind_name == "kCustom") {
3310 return HloInstruction::FusionKind::kCustom;
3311 }
3312 return InvalidArgument("Unknown fusion kind: %s", kind_name);
3313 }
3314
FrontendAttributesToString(const FrontendAttributes & frontend_attributes)3315 string FrontendAttributesToString(
3316 const FrontendAttributes& frontend_attributes) {
3317 std::vector<std::pair<string, string>> sorted_attributes(
3318 frontend_attributes.map().begin(), frontend_attributes.map().end());
3319 absl::c_sort(sorted_attributes);
3320 return absl::StrFormat(
3321 "{%s}", absl::StrJoin(sorted_attributes, ",", absl::PairFormatter("=")));
3322 }
3323
PaddingConfigToString(const PaddingConfig & padding)3324 string PaddingConfigToString(const PaddingConfig& padding) {
3325 bool has_interior_padding =
3326 absl::c_any_of(padding.dimensions(),
3327 [](const PaddingConfig::PaddingConfigDimension& dim) {
3328 return dim.interior_padding() != 0;
3329 });
3330 return StrJoin(
3331 padding.dimensions(), "x",
3332 [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
3333 StrAppend(
3334 out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
3335 has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
3336 });
3337 }
3338
OpMetadataToString(const OpMetadata & metadata)3339 string OpMetadataToString(const OpMetadata& metadata) {
3340 std::vector<string> result;
3341 if (!metadata.op_type().empty()) {
3342 result.push_back(StrCat("op_type=\"", CEscape(metadata.op_type()), "\""));
3343 }
3344 if (!metadata.op_name().empty()) {
3345 result.push_back(StrCat("op_name=\"", CEscape(metadata.op_name()), "\""));
3346 }
3347 if (!metadata.source_file().empty()) {
3348 result.push_back(
3349 StrCat("source_file=\"", CEscape(metadata.source_file()), "\""));
3350 }
3351 if (metadata.source_line() != 0) {
3352 result.push_back(StrCat("source_line=", metadata.source_line()));
3353 }
3354 return StrJoin(result, " ");
3355 }
3356
// Returns the lowercased proto enum name for the distribution.
string RandomDistributionToString(const RandomDistribution& distribution) {
  return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
}

// Returns the lowercased proto enum name for the precision.
string PrecisionToString(const PrecisionConfig::Precision& precision) {
  return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision));
}
3364
ConvolutionDimensionNumbersToString(const ConvolutionDimensionNumbers & dnums)3365 string ConvolutionDimensionNumbersToString(
3366 const ConvolutionDimensionNumbers& dnums) {
3367 // lhs_dims[i] is the symbol of the logical dimension i for the lhs
3368 // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
3369 std::vector<string> lhs_dims(2 + dnums.input_spatial_dimensions().size());
3370 lhs_dims[dnums.input_batch_dimension()] = 'b';
3371 lhs_dims[dnums.input_feature_dimension()] = 'f';
3372 for (int64 i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
3373 lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i);
3374 }
3375
3376 std::vector<string> rhs_dims(2 + dnums.kernel_spatial_dimensions().size());
3377 rhs_dims[dnums.kernel_input_feature_dimension()] = "i";
3378 rhs_dims[dnums.kernel_output_feature_dimension()] = "o";
3379 for (int64 i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) {
3380 rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i);
3381 }
3382
3383 std::vector<string> output_dims(2 + dnums.output_spatial_dimensions().size());
3384 output_dims[dnums.output_batch_dimension()] = 'b';
3385 output_dims[dnums.output_feature_dimension()] = 'f';
3386 for (int64 i = 0; i < dnums.output_spatial_dimensions().size(); ++i) {
3387 output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
3388 }
3389
3390 return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->",
3391 StrJoin(output_dims, ""));
3392 }
3393
ReplicaGroupsToString(const std::vector<ReplicaGroup> & replica_groups)3394 string ReplicaGroupsToString(const std::vector<ReplicaGroup>& replica_groups) {
3395 std::vector<string> replica_group_str;
3396 replica_group_str.reserve(replica_groups.size());
3397 for (const ReplicaGroup& group : replica_groups) {
3398 replica_group_str.push_back(
3399 StrCat("{", StrJoin(group.replica_ids(), ","), "}"));
3400 }
3401 return StrCat("{", StrJoin(replica_group_str, ","), "}");
3402 }
3403
// Inverse of RandomDistributionToString (case-insensitive). The
// lowercase-name -> enum map is built once on first call and intentionally
// leaked.
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
  static std::unordered_map<string, RandomDistribution>* map = [] {
    static auto* map = new std::unordered_map<string, RandomDistribution>;
    for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
      if (RandomDistribution_IsValid(i)) {
        auto value = static_cast<RandomDistribution>(i);
        (*map)[RandomDistributionToString(value)] = value;
      }
    }
    return map;
  }();
  auto found = map->find(absl::AsciiStrToLower(name));
  if (found == map->end()) {
    return InvalidArgument("Unknown distribution");
  }
  return found->second;
}
3421
StringToPrecision(const string & name)3422 StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) {
3423 static std::unordered_map<string, PrecisionConfig::Precision>* map = [] {
3424 static auto* map =
3425 new std::unordered_map<string, PrecisionConfig::Precision>;
3426 for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) {
3427 if (PrecisionConfig::Precision_IsValid(i)) {
3428 auto value = static_cast<PrecisionConfig::Precision>(i);
3429 (*map)[PrecisionToString(value)] = value;
3430 }
3431 }
3432 return map;
3433 }();
3434 auto found = map->find(absl::AsciiStrToLower(name));
3435 if (found == map->end()) {
3436 return InvalidArgument("Unknown distribution");
3437 }
3438 return found->second;
3439 }
3440
// Streams the canonical fusion-kind name (e.g. "kLoop").
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
  return os << ToString(kind);
}
3444
// Defines a deterministic strict-weak ordering over HloInstruction pointers:
// nullptr sorts before everything, then instructions are ordered by their
// module's unique id, then by their own unique id. This makes orderings
// stable across runs (unlike raw pointer comparison).
bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
                                  const HloInstruction* const& rhs) const {
  if (rhs == nullptr) {
    // Nothing compares less than nullptr.
    return false;
  }
  if (lhs == nullptr) {
    return true;
  }
  auto lhs_module = lhs->GetModule();
  auto rhs_module = rhs->GetModule();
  // Either both instructions are attached to a module or neither is;
  // comparing an attached instruction with a detached one is a bug.
  CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
        (lhs_module != nullptr && rhs_module != nullptr));
  if (lhs_module != nullptr &&
      lhs_module->unique_id() != rhs_module->unique_id()) {
    return lhs_module->unique_id() < rhs_module->unique_id();
  }
  return lhs->unique_id() < rhs->unique_id();
}
3464
CouldBeBitcast() const3465 bool HloInstruction::CouldBeBitcast() const {
3466 switch (opcode_) {
3467 case HloOpcode::kTranspose:
3468 return true;
3469 case HloOpcode::kReshape:
3470 return std::get<0>(ReshapeMerelyInsertsOrDeletes1SizedDimensions());
3471 default:
3472 return false;
3473 }
3474 }
3475
// Parses backend_config_ (a human-readable JSON string) into `proto`,
// clearing it first. An empty config string is treated as the empty proto
// rather than a parse error.
Status HloInstruction::GetBackendConfigInternal(
    tensorflow::protobuf::Message* proto) const {
  proto->Clear();

  // Empty string does not parse as valid JSON, but it's a valid backend config,
  // corresponding to the empty proto.
  if (backend_config_.empty()) {
    return Status::OK();
  }
  return tensorflow::HumanReadableJsonToProto(backend_config_, proto);
}

// Serializes `proto` to JSON and stores it as this instruction's backend
// config string.
Status HloInstruction::set_backend_config(
    const tensorflow::protobuf::Message& proto) {
  TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto));
  return Status::OK();
}

// Converts a backend-config proto to its raw human-readable JSON string
// representation.
/* static */ StatusOr<string> HloInstruction::BackendConfigToRawString(
    const tensorflow::protobuf::Message& proto) {
  string ret;
  // Pass ignore_accuracy_loss = true because estimated_cycles field can be
  // INT64_MAX. If ignore_accuracy_loss = false and estimated_cycles =
  // INT64_MAX, JsonFormat will return an error status, although there is no
  // accuracy loss for int64.
  TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(
      proto, &ret, /*ignore_accuracy_loss=*/true));
  return ret;
}
3505
// Returns the precision config. Only implemented for convolution and dot
// instructions; any other opcode dies with LOG(FATAL).
const PrecisionConfig& HloInstruction::precision_config() const {
  if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->precision_config();
  }
  if (auto* dot = DynCast<HloDotInstruction>(this)) {
    return dot->precision_config();
  }
  LOG(FATAL) << "Unimplemented method.";
}

// Mutable variant of precision_config(); same opcode restrictions.
PrecisionConfig* HloInstruction::mutable_precision_config() {
  if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->mutable_precision_config();
  }
  if (auto* dot = DynCast<HloDotInstruction>(this)) {
    return dot->mutable_precision_config();
  }
  LOG(FATAL) << "Unimplemented method.";
}
3525
GetModule() const3526 HloModule* HloInstruction::GetModule() const {
3527 if (parent_) {
3528 return parent_->parent();
3529 }
3530 return nullptr;
3531 }
3532
UniquifyName(NameUniquer * name_uniquer)3533 void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
3534 string parent_str = parent() == nullptr ? "noparent" : parent()->name();
3535 name_ = name_uniquer->GetUniqueName(name_);
3536 }
3537
// Stores the outer-dimension partition sizes used by parallel backends.
void HloInstruction::set_outer_dimension_partitions(
    const std::vector<int64>& outer_dimension_partitions) {
  outer_dimension_partitions_ = outer_dimension_partitions;
}

// TODO(b/80131774): Remove these temporary methods after transition.
//
// The accessors below delegate to the matching HloInstruction subclass;
// Cast<> requires this instruction to actually be of that subclass, while
// DynCast<> returns nullptr on mismatch and lets the caller branch.

// Valid only for batch-norm instructions.
int64 HloInstruction::feature_index() const {
  return Cast<HloBatchNormInstruction>(this)->feature_index();
}

// Valid only for batch-norm instructions.
float HloInstruction::epsilon() const {
  return Cast<HloBatchNormInstruction>(this)->epsilon();
}

// Valid only for FFT instructions.
FftType HloInstruction::fft_type() const {
  return Cast<HloFftInstruction>(this)->fft_type();
}

// Valid only for FFT instructions.
const std::vector<int64>& HloInstruction::fft_length() const {
  return Cast<HloFftInstruction>(this)->fft_length();
}

// Valid only for concatenate instructions.
int64 HloInstruction::concatenate_dimension() const {
  return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
}

// Handles both set-dimension-size and get-dimension-size instructions.
int64 HloInstruction::dimension() const {
  if (auto set_size = DynCast<HloSetDimensionSizeInstruction>(this)) {
    return set_size->dimension();
  }
  return Cast<HloGetDimensionSizeInstruction>(this)->dimension();
}

// Valid only for reshape instructions.
int64 HloInstruction::inferred_dimension() const {
  return Cast<HloReshapeInstruction>(this)->inferred_dimension();
}
3574
// True iff this is a transpose instruction that is rank-2-transpose-shaped;
// answers false (rather than dying) for non-transposes.
bool HloInstruction::IsRank2Transpose() const {
  auto transpose = DynCast<HloTransposeInstruction>(this);
  return transpose != nullptr && transpose->IsRank2Transpose();
}

// The slice accessors below are valid only for slice instructions.
int64 HloInstruction::slice_starts(int64 dimension) const {
  return Cast<HloSliceInstruction>(this)->slice_starts(dimension);
}

const std::vector<int64>& HloInstruction::slice_starts() const {
  return Cast<HloSliceInstruction>(this)->slice_starts();
}

int64 HloInstruction::slice_limits(int64 dimension) const {
  return Cast<HloSliceInstruction>(this)->slice_limits(dimension);
}

const std::vector<int64>& HloInstruction::slice_limits() const {
  return Cast<HloSliceInstruction>(this)->slice_limits();
}

int64 HloInstruction::slice_strides(int64 dimension) const {
  return Cast<HloSliceInstruction>(this)->slice_strides(dimension);
}

const std::vector<int64>& HloInstruction::slice_strides() const {
  return Cast<HloSliceInstruction>(this)->slice_strides();
}

// Valid only for constant instructions.
const Literal& HloInstruction::literal() const {
  return Cast<HloConstantInstruction>(this)->literal();
}

// True iff this instruction is a constant (kConstant opcode).
bool HloInstruction::IsConstant() const {
  return DynCast<HloConstantInstruction>(this) != nullptr;
}

// Valid only for constant instructions.
void HloInstruction::RelayoutConstant(const Layout& new_layout,
                                      const ShapeIndex& shape_index) {
  Cast<HloConstantInstruction>(this)->RelayoutConstant(new_layout, shape_index);
}

// Valid only for trace instructions.
string HloInstruction::TracingTag() const {
  return Cast<HloTraceInstruction>(this)->TracingTag();
}
3620
// Fusion delegation wrappers. All of these require this instruction to be a
// fusion (via Cast<HloFusionInstruction>), except IsMultiOutputFusion which
// uses DynCast and answers false for non-fusions.
HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
  return Cast<HloFusionInstruction>(this)->AddFusionOperand(new_operand);
}

// Delegates to HloFusionInstruction::MergeFusionInstruction.
void HloInstruction::MergeFusionInstruction(
    HloInstruction* instruction_to_merge) {
  return Cast<HloFusionInstruction>(this)->MergeFusionInstruction(
      Cast<HloFusionInstruction>(instruction_to_merge));
}

// Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
void HloInstruction::MergeFusionInstructionIntoMultiOutput(
    HloInstruction* instruction_to_merge) {
  return Cast<HloFusionInstruction>(this)
      ->MergeFusionInstructionIntoMultiOutput(
          Cast<HloFusionInstruction>(instruction_to_merge));
}

HloInstruction* HloInstruction::FuseInstruction(
    HloInstruction* instruction_to_fuse) {
  return Cast<HloFusionInstruction>(this)->FuseInstruction(instruction_to_fuse);
}

HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput(
    HloInstruction* instruction_to_fuse) {
  return Cast<HloFusionInstruction>(this)->FuseInstructionIntoMultiOutput(
      instruction_to_fuse);
}

HloComputation* HloInstruction::fused_instructions_computation() const {
  return Cast<HloFusionInstruction>(this)->fused_instructions_computation();
}

HloInstruction* HloInstruction::fused_expression_root() const {
  return Cast<HloFusionInstruction>(this)->fused_expression_root();
}

const tensorflow::gtl::iterator_range<UnwrappingIterator<
    std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
HloInstruction::fused_instructions() const {
  return Cast<HloFusionInstruction>(this)->fused_instructions();
}

const tensorflow::gtl::iterator_range<
    UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
HloInstruction::fused_instructions() {
  return Cast<HloFusionInstruction>(this)->fused_instructions();
}

int64 HloInstruction::fused_instruction_count() const {
  return Cast<HloFusionInstruction>(this)->fused_instruction_count();
}

HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
  return Cast<HloFusionInstruction>(this)->fused_parameter(parameter_number);
}

const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
  return Cast<HloFusionInstruction>(this)->fused_parameters();
}

// NOTE(review): the top-level `const` on the returned bool is a no-op for a
// by-value return; it is kept only because this definition must match the
// declaration in the header.
const bool HloInstruction::IsMultiOutputFusion() const {
  const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(this);
  return fusion != nullptr && fusion->IsMultiOutputFusion();
}

HloInstruction::FusionKind HloInstruction::fusion_kind() const {
  return Cast<HloFusionInstruction>(this)->fusion_kind();
}

void HloInstruction::set_fusion_kind(FusionKind kind) {
  return Cast<HloFusionInstruction>(this)->set_fusion_kind(kind);
}
3695
// More subclass delegation wrappers; each Cast<> requires the matching
// opcode/subclass.

// Valid only for rng instructions.
RandomDistribution HloInstruction::random_distribution() const {
  return Cast<HloRngInstruction>(this)->random_distribution();
}

// Valid only for parameter instructions.
int64 HloInstruction::parameter_number() const {
  return Cast<HloParameterInstruction>(this)->parameter_number();
}

void HloInstruction::set_parameter_replicated_at_leaf_buffers(
    absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
  return Cast<HloParameterInstruction>(this)
      ->set_parameter_replicated_at_leaf_buffers(
          parameter_replicated_at_leaf_buffers);
}

void HloInstruction::set_parameter_replicated_at_leaf_buffers(
    const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
  return Cast<HloParameterInstruction>(this)
      ->set_parameter_replicated_at_leaf_buffers(
          parameter_replicated_at_leaf_buffers);
}

const absl::optional<std::vector<bool>>&
HloInstruction::parameter_replicated_at_leaf_buffers() const {
  return Cast<HloParameterInstruction>(this)
      ->parameter_replicated_at_leaf_buffers();
}

// Valid only for get-tuple-element instructions.
int64 HloInstruction::tuple_index() const {
  return Cast<HloGetTupleElementInstruction>(this)->tuple_index();
}

void HloInstruction::set_tuple_index(int64 new_tuple_index) {
  return Cast<HloGetTupleElementInstruction>(this)->set_tuple_index(
      new_tuple_index);
}

// Valid only for reduce-precision instructions.
int32 HloInstruction::exponent_bits() const {
  return Cast<HloReducePrecisionInstruction>(this)->exponent_bits();
}

int32 HloInstruction::mantissa_bits() const {
  return Cast<HloReducePrecisionInstruction>(this)->mantissa_bits();
}

// Valid only for infeed instructions.
string HloInstruction::infeed_config() const {
  return Cast<HloInfeedInstruction>(this)->infeed_config();
}

void HloInstruction::set_infeed_config(const string& config) {
  return Cast<HloInfeedInstruction>(this)->set_infeed_config(config);
}

// Valid only for outfeed instructions.
const Shape& HloInstruction::outfeed_shape() const {
  return Cast<HloOutfeedInstruction>(this)->outfeed_shape();
}

const string& HloInstruction::outfeed_config() const {
  return Cast<HloOutfeedInstruction>(this)->outfeed_config();
}

// Valid only for collective instructions.
const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
  return Cast<HloCollectiveInstruction>(this)->replica_groups();
}

// Valid only for collective-permute instructions.
const std::vector<std::pair<int64, int64>>&
HloInstruction::source_target_pairs() const {
  return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
}

// Valid only for channel instructions.
absl::optional<int64> HloInstruction::channel_id() const {
  return Cast<HloChannelInstruction>(this)->channel_id();
}

void HloInstruction::set_channel_id(const absl::optional<int64>& channel_id) {
  return Cast<HloChannelInstruction>(this)->set_channel_id(channel_id);
}
3773
// Convolution-related accessors. These are implemented for convolution
// instructions and for custom-call instructions (which can carry convolution
// attributes); other opcodes die with LOG(FATAL) or fail the Cast<>.
const ConvolutionDimensionNumbers&
HloInstruction::convolution_dimension_numbers() const {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->convolution_dimension_numbers();
  }
  if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
    return custom_call->convolution_dimension_numbers();
  }
  LOG(FATAL) << "Unimplemented method.";
}

void HloInstruction::set_convolution_dimension_numbers(
    const ConvolutionDimensionNumbers& dnums) {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    convolution->set_convolution_dimension_numbers(dnums);
  } else if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
    custom_call->set_convolution_dimension_numbers(dnums);
  } else {
    LOG(FATAL) << "Unimplemented method.";
  }
}

int64 HloInstruction::feature_group_count() const {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->feature_group_count();
  }
  return Cast<HloCustomCallInstruction>(this)->feature_group_count();
}

void HloInstruction::set_feature_group_count(int64 feature_group_count) {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->set_feature_group_count(feature_group_count);
  }
  Cast<HloCustomCallInstruction>(this)->set_feature_group_count(
      feature_group_count);
}

int64 HloInstruction::batch_group_count() const {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->batch_group_count();
  }
  return Cast<HloCustomCallInstruction>(this)->batch_group_count();
}

void HloInstruction::set_batch_group_count(int64 batch_group_count) {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->set_batch_group_count(batch_group_count);
  }
  Cast<HloCustomCallInstruction>(this)->set_batch_group_count(
      batch_group_count);
}
3825
// Remaining subclass delegation wrappers; each Cast<> requires the matching
// opcode/subclass.

// Valid only for select-and-scatter instructions.
HloComputation* HloInstruction::select() const {
  return Cast<HloSelectAndScatterInstruction>(this)->select();
}

HloComputation* HloInstruction::scatter() const {
  return Cast<HloSelectAndScatterInstruction>(this)->scatter();
}

void HloInstruction::set_select(HloComputation* computation) {
  return Cast<HloSelectAndScatterInstruction>(this)->set_select(computation);
}

void HloInstruction::set_scatter(HloComputation* computation) {
  return Cast<HloSelectAndScatterInstruction>(this)->set_scatter(computation);
}

// Valid only for custom-call instructions.
const string& HloInstruction::custom_call_target() const {
  return Cast<HloCustomCallInstruction>(this)->custom_call_target();
}

// Valid only for pad instructions.
const PaddingConfig& HloInstruction::padding_config() const {
  return Cast<HloPadInstruction>(this)->padding_config();
}

// Valid only for dynamic-slice instructions.
int64 HloInstruction::slice_sizes(int64 dimension) const {
  return Cast<HloDynamicSliceInstruction>(this)->slice_sizes(dimension);
}

const std::vector<int64>& HloInstruction::dynamic_slice_sizes() const {
  return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes();
}

// Valid only for gather instructions.
const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
  return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
}

absl::Span<const int64> HloInstruction::gather_slice_sizes() const {
  return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
}

// Valid only for scatter instructions.
const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
    const {
  return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
}

// Valid only for dot instructions.
const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
  return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
}

// Valid only for domain instructions.
const DomainMetadata& HloInstruction::operand_side_metadata() const {
  return Cast<HloDomainInstruction>(this)->operand_side_metadata();
}

const DomainMetadata& HloInstruction::user_side_metadata() const {
  return Cast<HloDomainInstruction>(this)->user_side_metadata();
}

// Valid only for compare instructions.
ComparisonDirection HloInstruction::comparison_direction() const {
  return Cast<HloCompareInstruction>(this)->direction();
}

// Valid only for triangular-solve instructions.
const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
  return Cast<HloTriangularSolveInstruction>(this)->triangular_solve_options();
}

// Valid only for Cholesky instructions.
const CholeskyOptions& HloInstruction::cholesky_options() const {
  return Cast<HloCholeskyInstruction>(this)->cholesky_options();
}
3894
3895 } // namespace xla
3896