1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
17
18 #include <algorithm>
19 #include <ostream>
20 #include <set>
21 #include <string>
22 #include <unordered_set>
23 #include <utility>
24
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/container/inlined_vector.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/ascii.h"
31 #include "absl/strings/escaping.h"
32 #include "absl/strings/numbers.h"
33 #include "absl/strings/str_cat.h"
34 #include "absl/strings/str_join.h"
35 #include "absl/strings/string_view.h"
36 #include "absl/types/span.h"
37 #include "tensorflow/compiler/xla/layout_util.h"
38 #include "tensorflow/compiler/xla/literal.h"
39 #include "tensorflow/compiler/xla/protobuf_util.h"
40 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
41 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
42 #include "tensorflow/compiler/xla/service/hlo_computation.h"
43 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
44 #include "tensorflow/compiler/xla/service/hlo_module.h"
45 #include "tensorflow/compiler/xla/service/hlo_op_metadata.h"
46 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
47 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
48 #include "tensorflow/compiler/xla/service/name_uniquer.h"
49 #include "tensorflow/compiler/xla/shape_util.h"
50 #include "tensorflow/compiler/xla/status_macros.h"
51 #include "tensorflow/compiler/xla/types.h"
52 #include "tensorflow/compiler/xla/util.h"
53 #include "tensorflow/compiler/xla/xla_data.pb.h"
54 #include "tensorflow/core/lib/core/errors.h"
55 #include "tensorflow/core/lib/gtl/map_util.h"
56 #include "tensorflow/core/platform/human_readable_json.h"
57 #include "tensorflow/core/platform/logging.h"
58
59 namespace xla {
60
61 using absl::CEscape;
62 using absl::StrAppend;
63 using absl::StrCat;
64 using absl::StrJoin;
65
66 /* static */
CreateFromProto(const HloInstructionProto & proto,const absl::flat_hash_map<int64,HloInstruction * > & instruction_map,const absl::flat_hash_map<int64,HloComputation * > & computation_map,bool prohibit_empty_literal)67 StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
68 const HloInstructionProto& proto,
69 const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
70 const absl::flat_hash_map<int64, HloComputation*>& computation_map,
71 bool prohibit_empty_literal) {
72 TF_RET_CHECK(!proto.opcode().empty());
73 HloOpcode opcode;
74 auto opcode_or = StringToHloOpcode(proto.opcode());
75 absl::optional<ComparisonDirection> comparison_direction;
76 if (opcode_or.ok()) {
77 opcode = opcode_or.ConsumeValueOrDie();
78 } else {
79 // Unknown opcode. Try auto-upgrading deprecated "less-than",
80 // "greater-than", etc opcodes, which are now rolled into the kCompare
81 // opcode.
82 if (proto.opcode() == "equal-to") {
83 comparison_direction = ComparisonDirection::kEq;
84 } else if (proto.opcode() == "not-equal-to") {
85 comparison_direction = ComparisonDirection::kNe;
86 } else if (proto.opcode() == "greater-than-or-equal-to") {
87 comparison_direction = ComparisonDirection::kGe;
88 } else if (proto.opcode() == "greater-than") {
89 comparison_direction = ComparisonDirection::kGt;
90 } else if (proto.opcode() == "less-than-or-equal-to") {
91 comparison_direction = ComparisonDirection::kLe;
92 } else if (proto.opcode() == "less-than") {
93 comparison_direction = ComparisonDirection::kLt;
94 }
95 if (comparison_direction) {
96 opcode = HloOpcode::kCompare;
97 } else {
98 return InvalidArgument("Unknown opcode: %s", proto.opcode());
99 }
100 }
101
102 TF_RET_CHECK(proto.has_shape());
103
104 std::unique_ptr<HloInstruction> instruction;
105 const auto operands = [&instruction_map, &proto](int index) {
106 return instruction_map.at(proto.operand_ids(index));
107 };
108 const auto all_operands = [&instruction_map, &proto]() {
109 std::vector<HloInstruction*> result(proto.operand_ids_size());
110 std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
111 result.begin(), [&instruction_map](int64 operand_id) {
112 return instruction_map.at(operand_id);
113 });
114 return result;
115 };
116 const auto computations = [&computation_map, &proto](int index) {
117 return computation_map.at(proto.called_computation_ids(index));
118 };
119 const auto all_computations = [&computation_map, &proto]() {
120 std::vector<HloComputation*> result(proto.called_computation_ids_size());
121 std::transform(proto.called_computation_ids().begin(),
122 proto.called_computation_ids().end(), result.begin(),
123 [&computation_map](int64 computation_id) {
124 return computation_map.at(computation_id);
125 });
126 return result;
127 };
128
129 TF_RET_CHECK(
130 absl::c_all_of(proto.operand_ids(),
131 [&](int64 id) { return instruction_map.contains(id); }))
132 << proto.name() << " instruction contains invalid operand id(s)";
133
134 TF_RET_CHECK(
135 absl::c_all_of(proto.called_computation_ids(),
136 [&](int64 id) { return computation_map.contains(id); }))
137 << proto.name() << " instruction references invalid computation id(s)";
138
139 Shape shape(proto.shape());
140 TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
141
142 absl::optional<int> arity = HloOpcodeArity(opcode);
143 if (arity) {
144 TF_RET_CHECK(proto.operand_ids_size() == *arity)
145 << proto.opcode() << " instruction should have " << *arity
146 << " operands but sees " << proto.operand_ids_size();
147 }
148
149 switch (opcode) {
150 // Ops migrated to subclasses.
151 case HloOpcode::kBatchNormTraining:
152 instruction =
153 CreateBatchNormTraining(shape, operands(0), operands(1), operands(2),
154 proto.epsilon(), proto.feature_index());
155 break;
156 case HloOpcode::kBatchNormInference:
157 instruction = CreateBatchNormInference(
158 shape, operands(0), operands(1), operands(2), operands(3),
159 operands(4), proto.epsilon(), proto.feature_index());
160 break;
161 case HloOpcode::kBatchNormGrad:
162 instruction = CreateBatchNormGrad(shape, operands(0), operands(1),
163 operands(2), operands(3), operands(4),
164 proto.epsilon(), proto.feature_index());
165 break;
166 case HloOpcode::kFft: {
167 std::vector<int64> fft_length(proto.fft_length().begin(),
168 proto.fft_length().end());
169 instruction = CreateFft(shape, operands(0), proto.fft_type(),
170 absl::Span<const int64>(fft_length));
171 break;
172 }
173 case HloOpcode::kCopyStart: {
174 instruction = CreateCopyStart(shape, operands(0),
175 proto.is_cross_program_prefetch());
176 break;
177 }
178 case HloOpcode::kCompare: {
179 // Auto-upgraded from deprecated opcode skips the following.
180 if (!comparison_direction) {
181 TF_ASSIGN_OR_RETURN(
182 comparison_direction,
183 StringToComparisonDirection(proto.comparison_direction()));
184 }
185 auto comparison_type_str = proto.comparison_type();
186 if (!comparison_type_str.empty()) {
187 // If a comparison type is specified, it *must* be valid.
188 TF_ASSIGN_OR_RETURN(auto comparison_type,
189 StringToComparisonType(comparison_type_str));
190 instruction = CreateCompare(shape, operands(0), operands(1),
191 *comparison_direction, comparison_type);
192 } else {
193 // Allow the specify of comparison type to be optional.
194 // The comparison type will be determined by the types of the operands.
195 instruction = CreateCompare(shape, operands(0), operands(1),
196 *comparison_direction);
197 }
198 break;
199 }
200 case HloOpcode::kTriangularSolve: {
201 instruction = CreateTriangularSolve(shape, operands(0), operands(1),
202 proto.triangular_solve_options());
203 break;
204 }
205 case HloOpcode::kCholesky: {
206 instruction =
207 CreateCholesky(shape, operands(0), proto.cholesky_options());
208 break;
209 }
210 case HloOpcode::kSend:
211 instruction = CreateSend(operands(0), operands(1), proto.channel_id(),
212 proto.is_host_transfer());
213 break;
214 case HloOpcode::kSendDone:
215 instruction = CreateSendDone(operands(0), proto.is_host_transfer());
216 break;
217 case HloOpcode::kRecv:
218 instruction = CreateRecv(shape.tuple_shapes(0), operands(0),
219 proto.channel_id(), proto.is_host_transfer());
220 break;
221 case HloOpcode::kRecvDone:
222 instruction = CreateRecvDone(operands(0), proto.is_host_transfer());
223 break;
224 case HloOpcode::kReverse:
225 instruction = CreateReverse(shape, operands(0),
226 std::vector<int64>(proto.dimensions().begin(),
227 proto.dimensions().end()));
228 break;
229 case HloOpcode::kConcatenate:
230 TF_RET_CHECK(proto.dimensions_size() == 1)
231 << "Concatenate instruction should have 1 dimension but sees "
232 << proto.dimensions_size();
233 instruction =
234 CreateConcatenate(shape, all_operands(), proto.dimensions(0));
235 break;
236 case HloOpcode::kConditional: {
237 TF_RET_CHECK(proto.called_computation_ids_size() > 0)
238 << "conditional should have at least 1 called computation";
239 if (operands(0)->shape().element_type() == PRED) {
240 TF_RET_CHECK(proto.called_computation_ids_size() == 2)
241 << "conditional should have exactly 2 called computations but got "
242 << proto.called_computation_ids_size();
243 }
244 TF_RET_CHECK(proto.operand_ids_size() ==
245 proto.called_computation_ids_size() + 1)
246 << "conditional should have one branch_index operand plus one "
247 "operand per called computation but got "
248 << proto.operand_ids_size() << " operands for "
249 << proto.called_computation_ids_size() << " branch computations";
250 auto cond_operands = all_operands();
251 instruction =
252 CreateConditional(shape, cond_operands[0], all_computations(),
253 absl::MakeSpan(cond_operands).subspan(1));
254 break;
255 }
256 case HloOpcode::kReduce:
257 TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
258 << "Reduce instruction should have an even number of operands but "
259 "sees "
260 << proto.operand_ids_size();
261 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
262 << "Reduce instruction should have 1 called computation but sees "
263 << proto.called_computation_ids_size();
264 {
265 const auto reduce_operands = all_operands();
266 auto inputs = absl::MakeSpan(reduce_operands)
267 .subspan(0, reduce_operands.size() / 2);
268 auto init_values =
269 absl::MakeSpan(reduce_operands)
270 .subspan(reduce_operands.size() / 2, reduce_operands.size());
271 instruction =
272 CreateReduce(shape, inputs, init_values,
273 std::vector<int64>(proto.dimensions().begin(),
274 proto.dimensions().end()),
275 computations(0));
276 }
277 break;
278 case HloOpcode::kSort: {
279 TF_RET_CHECK(proto.operand_ids_size() >= 1)
280 << "Sort instruction should have at least 1 operand but has "
281 << proto.operand_ids_size();
282 TF_RET_CHECK(proto.dimensions().size() == 1)
283 << "Sort instruction should have 1 dimension";
284 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
285 << "Sort instruction should one called computation but sees "
286 << proto.called_computation_ids_size();
287 auto sort_operands = all_operands();
288 instruction = CreateSort(shape, proto.dimensions(0), all_operands(),
289 computations(0), proto.is_stable());
290 break;
291 }
292 case HloOpcode::kTranspose:
293 instruction =
294 CreateTranspose(shape, operands(0),
295 std::vector<int64>(proto.dimensions().begin(),
296 proto.dimensions().end()));
297 break;
298 case HloOpcode::kBroadcast:
299 instruction =
300 CreateBroadcast(shape, operands(0),
301 std::vector<int64>(proto.dimensions().begin(),
302 proto.dimensions().end()));
303 break;
304 case HloOpcode::kMap:
305 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
306 << "Map instruction should have 1 called computation but sees "
307 << proto.called_computation_ids_size();
308 instruction = CreateMap(shape, all_operands(), computations(0));
309 break;
310 case HloOpcode::kSlice: {
311 std::vector<int64> slice_starts, slice_limits, slice_strides;
312 for (const HloInstructionProto::SliceDimensions& slice_dimensions :
313 proto.slice_dimensions()) {
314 slice_starts.push_back(slice_dimensions.start());
315 slice_limits.push_back(slice_dimensions.limit());
316 slice_strides.push_back(slice_dimensions.stride());
317 }
318 instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits,
319 slice_strides);
320 break;
321 }
322 case HloOpcode::kConstant: {
323 // TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
324 if (proto.has_literal()) {
325 TF_ASSIGN_OR_RETURN(
326 auto literal,
327 Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
328 instruction = CreateConstant(std::move(literal));
329 // Literal's shape may have no/different tiling info.
330 TF_RET_CHECK(Shape::Equal().MinorToMajorOnlyInLayout()(
331 instruction->shape(), shape))
332 << instruction->shape().ToString(true) << " vs "
333 << shape.ToString(true);
334 *instruction->mutable_shape() = shape;
335 } else {
336 instruction = absl::make_unique<HloConstantInstruction>(shape);
337 }
338 break;
339 }
340 case HloOpcode::kTrace: {
341 TF_RET_CHECK(proto.has_literal());
342 TF_ASSIGN_OR_RETURN(
343 auto literal,
344 Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
345 instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
346 break;
347 }
348 case HloOpcode::kFusion: {
349 // In the proto, fused computations are held exclusively within the
350 // HloInstructionProto and do not appear as an HloComputationProto within
351 // the HloModuleProto.
352 TF_RET_CHECK(!proto.fusion_kind().empty());
353 TF_ASSIGN_OR_RETURN(FusionKind fusion_kind,
354 StringToFusionKind(proto.fusion_kind()));
355
356 // Find the fused computation and set its fusion instruction.
357 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
358 << "Expect 1 called computation for fusion instruction but sees "
359 << proto.called_computation_ids_size();
360 const int64 fusion_id = proto.called_computation_ids(0);
361 auto* fused_computation =
362 tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
363 TF_RET_CHECK(fused_computation != nullptr)
364 << "No fusion computation with id " << fusion_id;
365 instruction =
366 CreateFusion(shape, fusion_kind, all_operands(), fused_computation);
367 break;
368 }
369 case HloOpcode::kRng:
370 instruction = CreateRng(shape, proto.distribution(), all_operands());
371 break;
372 case HloOpcode::kRngBitGenerator:
373 instruction =
374 CreateRngBitGenerator(shape, operands(0), proto.rng_algorithm());
375 break;
376 case HloOpcode::kRngGetAndUpdateState:
377 instruction = CreateRngGetAndUpdateState(shape, proto.delta());
378 break;
379 case HloOpcode::kParameter:
380 instruction =
381 CreateParameter(proto.parameter_number(), shape, proto.name());
382 if (!proto.parameter_replication().replicated_at_leaf_buffers().empty()) {
383 instruction->set_parameter_replicated_at_leaf_buffers(
384 proto.parameter_replication().replicated_at_leaf_buffers());
385 }
386 break;
387 case HloOpcode::kGetTupleElement:
388 instruction =
389 CreateGetTupleElement(shape, operands(0), proto.tuple_index());
390 break;
391 case HloOpcode::kReducePrecision:
392 instruction = CreateReducePrecision(
393 shape, operands(0), proto.exponent_bits(), proto.mantissa_bits());
394 break;
395 case HloOpcode::kInfeed: {
396 TF_RET_CHECK(shape.IsTuple() &&
397 (ShapeUtil::TupleElementCount(shape) == 2))
398 << "Infeed should have a tuple shape with 2 operands, but has: "
399 << shape;
400 const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0);
401 instruction =
402 CreateInfeed(data_shape, operands(0), proto.infeed_config());
403 } break;
404 case HloOpcode::kOutfeed: {
405 Shape outfeed_shape(proto.outfeed_shape());
406 TF_RETURN_IF_ERROR(
407 ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape));
408 instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1),
409 proto.outfeed_config());
410 break;
411 }
412 case HloOpcode::kAllGather: {
413 absl::optional<int64> channel_id;
414 if (proto.channel_id() > 0) {
415 channel_id = proto.channel_id();
416 }
417
418 TF_RET_CHECK(proto.dimensions_size() == 1)
419 << "AllGather cannot have more than 1 all-gather dimensions";
420 TF_RET_CHECK(all_operands().size() == 1)
421 << "AllGather must have a single operand";
422 int64 all_gather_dimension = proto.dimensions(0);
423 instruction = CreateAllGather(
424 shape, operands(0), all_gather_dimension,
425 std::vector<ReplicaGroup>(proto.replica_groups().begin(),
426 proto.replica_groups().end()),
427 proto.constrain_layout(), channel_id, proto.use_global_device_ids());
428 break;
429 }
430 case HloOpcode::kAllReduce: {
431 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
432 << "AllReduce should have 1 called computation but sees "
433 << proto.called_computation_ids_size();
434 TF_RET_CHECK(proto.channel_id() <= 0 || proto.all_reduce_id() <= 0)
435 << "AllReduce cannot have both channel_id() and all_reduce_id()";
436 absl::optional<int64> channel_id;
437 if (proto.channel_id() > 0) {
438 channel_id = proto.channel_id();
439 }
440 if (proto.all_reduce_id() > 0) {
441 channel_id = proto.all_reduce_id();
442 }
443 instruction = CreateAllReduce(
444 shape, all_operands(), computations(0),
445 /*replica_groups=*/
446 std::vector<ReplicaGroup>(proto.replica_groups().begin(),
447 proto.replica_groups().end()),
448 /*constrain_layout=*/proto.constrain_layout(),
449 /*channel_id=*/channel_id,
450 /*use_global_device_ids=*/proto.use_global_device_ids());
451 break;
452 }
453 case HloOpcode::kAllToAll: {
454 absl::optional<int64> channel_id;
455 if (proto.channel_id() > 0) {
456 channel_id = proto.channel_id();
457 }
458 absl::optional<int64> split_dimension;
459 if (proto.dimensions_size() > 0) {
460 TF_RET_CHECK(proto.dimensions_size() == 1)
461 << "AllToAll cannot have more than 1 dimension (split dimension)";
462 TF_RET_CHECK(all_operands().size() == 1)
463 << "AllToAll must have a single operand when the split dimension "
464 "is specified";
465 split_dimension = proto.dimensions(0);
466 }
467 instruction = CreateAllToAll(
468 shape, all_operands(),
469 /*replica_groups=*/
470 std::vector<ReplicaGroup>(proto.replica_groups().begin(),
471 proto.replica_groups().end()),
472 /*constrain_layout=*/proto.constrain_layout(),
473 /*channel_id=*/channel_id, split_dimension);
474 break;
475 }
476 case HloOpcode::kCollectivePermute:
477 case HloOpcode::kCollectivePermuteStart: {
478 std::vector<std::pair<int64, int64>> source_target_pairs(
479 proto.source_target_pairs_size());
480 absl::optional<int64> channel_id;
481 if (proto.channel_id() > 0) {
482 channel_id = proto.channel_id();
483 }
484 for (int i = 0; i < source_target_pairs.size(); i++) {
485 source_target_pairs[i].first = proto.source_target_pairs(i).source();
486 source_target_pairs[i].second = proto.source_target_pairs(i).target();
487 }
488
489 if (opcode == HloOpcode::kCollectivePermute) {
490 instruction = CreateCollectivePermute(shape, operands(0),
491 source_target_pairs, channel_id);
492 } else if (opcode == HloOpcode::kCollectivePermuteStart) {
493 instruction = CreateCollectivePermuteStart(
494 shape, operands(0), source_target_pairs, channel_id);
495 } else {
496 LOG(FATAL) << "Expect CollectivePermute or CollectivePermuteStart, "
497 << "but got " << HloOpcodeString(opcode);
498 }
499 break;
500 }
501 case HloOpcode::kReplicaId: {
502 instruction = CreateReplicaId(shape);
503 break;
504 }
505 case HloOpcode::kPartitionId: {
506 instruction = CreatePartitionId(shape);
507 break;
508 }
509 case HloOpcode::kConvolution: {
510 TF_RET_CHECK(proto.has_window());
511 TF_RET_CHECK(proto.has_convolution_dimension_numbers());
512 PrecisionConfig precision_config = proto.precision_config();
513 precision_config.mutable_operand_precision()->Resize(
514 proto.operand_ids_size(), PrecisionConfig::DEFAULT);
515 instruction = CreateConvolve(
516 shape, operands(0), operands(1),
517 std::max<int64>(proto.feature_group_count(), 1),
518 std::max<int64>(proto.batch_group_count(), 1), proto.window(),
519 proto.convolution_dimension_numbers(), precision_config);
520 break;
521 }
522 case HloOpcode::kReduceWindow:
523 TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
524 << "Reduce window should have an even number of operands but "
525 "sees "
526 << proto.operand_ids_size();
527 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
528 << "ReduceWindow should have 1 called computation but sees "
529 << proto.called_computation_ids_size();
530 {
531 const auto reduce_operands = all_operands();
532 auto inputs = absl::MakeSpan(reduce_operands)
533 .subspan(0, reduce_operands.size() / 2);
534 auto init_values =
535 absl::MakeSpan(reduce_operands)
536 .subspan(reduce_operands.size() / 2, reduce_operands.size());
537 instruction = CreateReduceWindow(shape, inputs, init_values,
538 proto.window(), computations(0));
539 }
540 break;
541 case HloOpcode::kSelectAndScatter:
542 TF_RET_CHECK(proto.called_computation_ids_size() == 2)
543 << "SelectAndScatter should have 2 called computations but sees "
544 << proto.called_computation_ids_size();
545 instruction = CreateSelectAndScatter(shape, operands(0), computations(0),
546 proto.window(), operands(1),
547 operands(2), computations(1));
548 break;
549 case HloOpcode::kCustomCall: {
550 if (proto.constrain_layout()) {
551 // A proto RepeatedPtrField cannot be converted to a Span (it is a
552 // vector of pointers essentially) so create a vector of shapes to pass
553 // in.
554 std::vector<Shape> operand_shapes;
555 for (const ShapeProto& shape_proto :
556 proto.operand_shapes_with_layout()) {
557 operand_shapes.emplace_back(shape_proto);
558 }
559 instruction =
560 CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
561 operand_shapes, proto.backend_config());
562 } else {
563 if (proto.called_computation_ids_size() == 1) {
564 instruction = CreateCustomCall(shape, all_operands(), computations(0),
565 proto.custom_call_target(),
566 proto.backend_config());
567 } else if (proto.called_computation_ids_size() > 1) {
568 instruction = CreateCustomCall(
569 shape, all_operands(), all_computations(),
570 proto.custom_call_target(), proto.backend_config());
571
572 } else {
573 instruction = CreateCustomCall(shape, all_operands(),
574 proto.custom_call_target(),
575 proto.backend_config());
576 }
577 }
578 auto custom_call_instr =
579 Cast<HloCustomCallInstruction>(instruction.get());
580 if (proto.has_window()) {
581 custom_call_instr->set_window(proto.window());
582 }
583 if (proto.has_literal()) {
584 TF_ASSIGN_OR_RETURN(
585 auto literal,
586 Literal::CreateFromProto(proto.literal(), prohibit_empty_literal));
587 custom_call_instr->set_literal(std::move(literal));
588 }
589 if (proto.has_convolution_dimension_numbers()) {
590 custom_call_instr->set_convolution_dimension_numbers(
591 proto.convolution_dimension_numbers());
592 }
593 custom_call_instr->set_feature_group_count(
594 std::max(static_cast<int64>(proto.feature_group_count()), int64{1}));
595 custom_call_instr->set_batch_group_count(
596 std::max(static_cast<int64>(proto.batch_group_count()), int64{1}));
597 custom_call_instr->set_custom_call_has_side_effect(
598 proto.custom_call_has_side_effect());
599 custom_call_instr->set_padding_type(proto.padding_type());
600
601 PrecisionConfig precision_config = proto.precision_config();
602 precision_config.mutable_operand_precision()->Resize(
603 proto.operand_ids_size(), PrecisionConfig::DEFAULT);
604 *custom_call_instr->mutable_precision_config() = precision_config;
605 std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
606 output_to_operand_aliasing;
607 for (const auto& aliasing : proto.custom_call_output_operand_aliasing()) {
608 output_to_operand_aliasing.emplace_back(
609 ShapeIndex(aliasing.output_shape_index().begin(),
610 aliasing.output_shape_index().end()),
611 std::pair<int64, ShapeIndex>{
612 aliasing.operand_index(),
613 ShapeIndex(aliasing.operand_shape_index().begin(),
614 aliasing.operand_shape_index().end())});
615 }
616 custom_call_instr->set_output_to_operand_aliasing(
617 std::move(output_to_operand_aliasing));
618 break;
619 }
620 case HloOpcode::kPad:
621 TF_RET_CHECK(proto.has_padding_config());
622 instruction =
623 CreatePad(shape, operands(0), operands(1), proto.padding_config());
624 break;
625 case HloOpcode::kDynamicSlice: {
626 std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size());
627 absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
628 TF_RET_CHECK(proto.operand_ids_size() >= 1)
629 << "DynamicSlice instruction should have at least 1 operands but "
630 "sees "
631 << proto.operand_ids_size();
632 // TODO(b/118437727): Old form, make the check unconditional.
633 if (proto.operand_ids_size() != 2 || operands(1)->shape().rank() != 1) {
634 auto expected_operands = 1 + operands(0)->shape().rank();
635 TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
636 << "DynamicSlice instruction should have " << expected_operands
637 << " operands, but has " << proto.operand_ids_size();
638 }
639 const auto& operand_vector = all_operands();
640 instruction = CreateDynamicSlice(
641 shape, operands(0), absl::MakeSpan(operand_vector).subspan(1),
642 slice_sizes);
643 break;
644 }
645 case HloOpcode::kDynamicUpdateSlice: {
646 TF_RET_CHECK(proto.operand_ids_size() >= 2)
647 << "DynamicUpdateSlice instruction should have at least 2 operands "
648 "but sees "
649 << proto.operand_ids_size();
650 // TODO(b/118437727): Old form, make the check unconditional.
651 if (proto.operand_ids_size() != 3 || operands(2)->shape().rank() != 1) {
652 auto expected_operands = 2 + operands(0)->shape().rank();
653 TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
654 << "DynamicUpdateSlice instruction should have "
655 << expected_operands << " operands, but has "
656 << proto.operand_ids_size();
657 }
658 const auto& operand_vector = all_operands();
659 instruction =
660 CreateDynamicUpdateSlice(shape, operands(0), operands(1),
661 absl::MakeSpan(operand_vector).subspan(2));
662
663 break;
664 }
665 case HloOpcode::kGather: {
666 TF_RET_CHECK(proto.has_gather_dimension_numbers())
667 << "Gather instruction should have GatherDimensionNumbers set.";
668 auto gather_dimension_numbers = absl::make_unique<GatherDimensionNumbers>(
669 proto.gather_dimension_numbers());
670 std::vector<int64> gather_slice_sizes;
671 for (int64 bound : proto.gather_slice_sizes()) {
672 gather_slice_sizes.push_back(bound);
673 }
674 instruction = CreateGather(shape, operands(0), operands(1),
675 *gather_dimension_numbers, gather_slice_sizes,
676 proto.indices_are_sorted());
677 break;
678 }
679 case HloOpcode::kScatter: {
680 TF_RET_CHECK(proto.has_scatter_dimension_numbers())
681 << "Scatter instruction should have ScatterDimensionNumbers set.";
682 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
683 << "Scatter instruction should have 1 called computation but sees "
684 << proto.called_computation_ids_size();
685 auto scatter_dimension_numbers =
686 absl::make_unique<ScatterDimensionNumbers>(
687 proto.scatter_dimension_numbers());
688 instruction =
689 CreateScatter(shape, operands(0), operands(1), operands(2),
690 computations(0), *scatter_dimension_numbers,
691 proto.indices_are_sorted(), proto.unique_indices());
692 break;
693 }
694 case HloOpcode::kIota:
695 TF_RET_CHECK(proto.dimensions_size() == 1)
696 << "Iota instruction should have 1 dimension but sees "
697 << proto.dimensions_size();
698 instruction = CreateIota(shape, proto.dimensions(0));
699 break;
700 case HloOpcode::kDot: {
701 TF_RET_CHECK(proto.has_dot_dimension_numbers())
702 << "Dot instruction should have dot_dimension_numbers.";
703 PrecisionConfig precision_config = proto.precision_config();
704 precision_config.mutable_operand_precision()->Resize(
705 proto.operand_ids_size(), PrecisionConfig::DEFAULT);
706 instruction = absl::make_unique<HloDotInstruction>(
707 shape, operands(0), operands(1), proto.dot_dimension_numbers(),
708 precision_config);
709 break;
710 }
711 case HloOpcode::kDomain: {
712 std::shared_ptr<const HloSharding> entry_hlo_sharding;
713 std::shared_ptr<const HloSharding> exit_hlo_sharding;
714 if (proto.has_domain_entry_sharding()) {
715 TF_ASSIGN_OR_RETURN(
716 HloSharding sharding,
717 HloSharding::FromProto(proto.domain_entry_sharding()));
718 entry_hlo_sharding = std::make_shared<const HloSharding>(sharding);
719 }
720 if (proto.has_domain_exit_sharding()) {
721 TF_ASSIGN_OR_RETURN(
722 HloSharding sharding,
723 HloSharding::FromProto(proto.domain_exit_sharding()));
724 exit_hlo_sharding = std::make_shared<const HloSharding>(sharding);
725 }
726 instruction = absl::make_unique<HloDomainInstruction>(
727 shape, operands(0),
728 absl::make_unique<ShardingMetadata>(entry_hlo_sharding),
729 absl::make_unique<ShardingMetadata>(exit_hlo_sharding));
730 break;
731 }
732 case HloOpcode::kGetDimensionSize:
733 TF_RET_CHECK(proto.dimensions_size() == 1);
734 instruction =
735 CreateGetDimensionSize(shape, operands(0), proto.dimensions(0));
736 break;
737 case HloOpcode::kSetDimensionSize:
738 TF_RET_CHECK(proto.dimensions_size() == 1);
739 instruction = CreateSetDimensionSize(shape, operands(0), operands(1),
740 proto.dimensions(0));
741 break;
742 case HloOpcode::kReshape: {
743 int64 inferred_dimension = -1;
744 if (!proto.dimensions().empty()) {
745 inferred_dimension = proto.dimensions()[0];
746 }
747 TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
748 ShapeUtil::ElementsIn(shape) ==
749 ShapeUtil::ElementsIn(operands(0)->shape()))
750 << "shape: " << ShapeUtil::HumanString(shape)
751 << " operand: " << ShapeUtil::HumanString(operands(0)->shape());
752 instruction = CreateReshape(shape, operands(0), inferred_dimension);
753 break;
754 }
755 case HloOpcode::kDynamicReshape: {
756 TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() &&
757 ShapeUtil::ElementsIn(shape) ==
758 ShapeUtil::ElementsIn(operands(0)->shape()))
759 << "shape: " << ShapeUtil::HumanString(shape)
760 << " operand: " << ShapeUtil::HumanString(operands(0)->shape());
761 const auto& operand_vector = all_operands();
762 instruction = CreateDynamicReshape(
763 shape, operands(0), absl::MakeSpan(operand_vector).subspan(1));
764 break;
765 }
766 default: {
767 instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
768 for (const int64 operand_id : proto.operand_ids()) {
769 instruction->AppendOperand(instruction_map.at(operand_id));
770 }
771 if (instruction->opcode() != HloOpcode::kFusion) {
772 if (instruction->opcode() == HloOpcode::kCall) {
773 TF_RET_CHECK(proto.called_computation_ids_size() == 1)
774 << "Call should have 1 called computation but has "
775 << proto.called_computation_ids_size();
776 }
777 for (const int64 computation_id : proto.called_computation_ids()) {
778 instruction->called_computations_.push_back(
779 computation_map.at(computation_id));
780 }
781 }
782 TF_RET_CHECK(!proto.has_precision_config())
783 << instruction->opcode() << proto.DebugString();
784 TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
785 break;
786 }
787 }
788
789 for (const int64 predecessor_id : proto.control_predecessor_ids()) {
790 TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
791 << "No instruction with id " << predecessor_id;
792 TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
793 ->AddControlDependencyTo(instruction.get()));
794 }
795
796 TF_RET_CHECK(!proto.name().empty());
797 instruction->SetAndSanitizeName(proto.name());
798 instruction->metadata_ = proto.metadata();
799 instruction->backend_config_ = proto.backend_config();
800 instruction->outer_dimension_partitions_.assign(
801 proto.outer_dimension_partitions().begin(),
802 proto.outer_dimension_partitions().end());
803
804 TF_RET_CHECK(proto.id() >= 0)
805 << "Instruction with negative id: " << proto.id();
806 TF_RET_CHECK(proto.id() <= INT_MAX)
807 << "Instruction with id > INT_MAX: " << proto.id();
808 instruction->unique_id_ = proto.id();
809
810 if (proto.has_sharding()) {
811 TF_ASSIGN_OR_RETURN(const auto& sharding,
812 HloSharding::FromProto(proto.sharding()));
813 instruction->set_sharding(sharding);
814 }
815
816 if (proto.has_frontend_attributes()) {
817 instruction->set_frontend_attributes(proto.frontend_attributes());
818 }
819
820 return std::move(instruction);
821 }
822
// Creates a parameter instruction: the `parameter_number`-th input of the
// enclosing computation, with the given shape and name.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
    int64 parameter_number, const Shape& shape, const string& name) {
  return absl::make_unique<HloParameterInstruction>(parameter_number, shape,
                                                    name);
}
828
// Creates a trace instruction tagging `operand` with `tag`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
    const string& tag, HloInstruction* operand) {
  return absl::make_unique<HloTraceInstruction>(tag, operand);
}
833
// Creates a constant instruction that owns the given literal value.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
    Literal literal) {
  return absl::make_unique<HloConstantInstruction>(std::move(literal));
}
838
// Creates an iota instruction filling `shape` along `iota_dimension`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
    const Shape& shape, int64 iota_dimension) {
  return absl::make_unique<HloIotaInstruction>(shape, iota_dimension);
}
843
// Creates a get-tuple-element instruction extracting element `index` of the
// tuple-shaped `operand`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetTupleElement(const Shape& shape,
                                      HloInstruction* operand, int64 index) {
  return absl::make_unique<HloGetTupleElementInstruction>(shape, operand,
                                                          index);
}
850
// Creates an rng instruction drawing from `distribution`, parameterized by
// the given operands (e.g. bounds or mean/stddev).
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
    const Shape& shape, RandomDistribution distribution,
    absl::Span<HloInstruction* const> parameters) {
  return absl::make_unique<HloRngInstruction>(shape, distribution, parameters);
}
856
// Creates an rng-get-and-update-state instruction advancing the RNG state by
// `delta`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateRngGetAndUpdateState(const Shape& shape, int64 delta) {
  return absl::make_unique<HloRngGetAndUpdateStateInstruction>(shape, delta);
}
861
// Creates an rng-bit-generator instruction producing random bits from `state`
// using the given algorithm.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
                                      RandomAlgorithm algorithm) {
  return absl::make_unique<HloRngBitGeneratorInstruction>(shape, state,
                                                          algorithm);
}
868
CreateNary(const Shape & shape,HloOpcode opcode,absl::Span<HloInstruction * const> operands)869 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
870 const Shape& shape, HloOpcode opcode,
871 absl::Span<HloInstruction* const> operands) {
872 if (opcode == HloOpcode::kCopy) {
873 // It is impossible to copy an opaque shape, we don't know how big it is.
874 CHECK(!shape.IsOpaque());
875 }
876 auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
877 for (auto operand : operands) {
878 instruction->AppendOperand(operand);
879 }
880 return instruction;
881 }
882
// Creates a unary instruction applying `opcode` to `operand`. Crashes
// (LOG(FATAL)) on any opcode outside the accepted list below; unary ops with
// auxiliary fields must use their dedicated Create* factory.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
    const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
  // Only certain opcodes are supported with CreateUnary: opcodes of unary
  // instructions with no auxiliary fields.
  switch (opcode) {
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kBitcast:
    case HloOpcode::kCeil:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCos:
    case HloOpcode::kClz:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kLogistic:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kTanh:
      break;
    default:
      LOG(FATAL) << "Invalid unary instruction opcode "
                 << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {operand});
}
922
// Creates a binary instruction applying `opcode` to `lhs` and `rhs`. Crashes
// (LOG(FATAL)) on any opcode outside the accepted list below; binary ops with
// auxiliary fields (e.g. kCompare, kDot) must use their dedicated factory.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
    const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
    HloInstruction* rhs) {
  // Only certain opcodes are supported with CreateBinary: opcodes of binary
  // instructions with no auxiliary fields.
  switch (opcode) {
    case HloOpcode::kAdd:
    case HloOpcode::kAtan2:
    case HloOpcode::kDivide:
    case HloOpcode::kComplex:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kSubtract:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      break;
    default:
      LOG(FATAL) << "Invalid binary instruction opcode "
                 << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {lhs, rhs});
}
952
CreateTernary(const Shape & shape,HloOpcode opcode,HloInstruction * lhs,HloInstruction * rhs,HloInstruction * ehs)953 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
954 const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
955 HloInstruction* rhs, HloInstruction* ehs) {
956 // Only certain opcodes are supported with CreateTernary: opcodes of ternary
957 // instructions with no auxiliary fields.
958 switch (opcode) {
959 case HloOpcode::kClamp:
960 case HloOpcode::kSelect:
961 case HloOpcode::kTupleSelect:
962 break;
963 default:
964 LOG(FATAL) << "Invalid ternary instruction opcode "
965 << HloOpcodeString(opcode);
966 }
967 return CreateNary(shape, opcode, {lhs, rhs, ehs});
968 }
969
// Creates a variadic instruction; currently only kTuple is accepted here
// (other variadic ops have dedicated factories).
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
    const Shape& shape, HloOpcode opcode,
    absl::Span<HloInstruction* const> operands) {
  CHECK_EQ(HloOpcode::kTuple, opcode);
  return CreateNary(shape, opcode, operands);
}
976
// Creates a map instruction applying `map_computation` elementwise over the
// operands.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloComputation* map_computation) {
  return absl::make_unique<HloMapInstruction>(shape, operands, map_computation);
}
982
// Creates a convolution of `lhs` (input) with `rhs` (kernel) under the given
// window, dimension numbers, grouping and precision configuration.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    int64 feature_group_count, int64 batch_group_count, const Window& window,
    const ConvolutionDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config) {
  return absl::make_unique<HloConvolutionInstruction>(
      shape, lhs, rhs, feature_group_count, batch_group_count, window,
      dimension_numbers, precision_config);
}
992
// Creates an FFT instruction of the given type/length over `operand`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
    const Shape& shape, HloInstruction* operand, FftType fft_type,
    absl::Span<const int64> fft_length) {
  return absl::make_unique<HloFftInstruction>(shape, operand, fft_type,
                                              fft_length);
}
999
// Creates a copy-start instruction for `operand`; pairs with kCopyDone.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCopyStart(
    const Shape& shape, HloInstruction* operand,
    bool is_cross_program_prefetch) {
  return absl::make_unique<HloCopyStartInstruction>(shape, operand,
                                                    is_cross_program_prefetch);
}
1006
// Creates a compare instruction between `lhs` and `rhs` with the given
// direction and optional comparison type.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    ComparisonDirection direction, absl::optional<Comparison::Type> type) {
  return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction,
                                                  type);
}
1013
// Creates a triangular-solve instruction solving a * x = b (per `options`).
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
                                      HloInstruction* b,
                                      const TriangularSolveOptions& options) {
  return absl::make_unique<HloTriangularSolveInstruction>(shape, a, b, options);
}
1020
// Creates a Cholesky decomposition instruction for `a` (per `options`).
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCholesky(
    const Shape& shape, HloInstruction* a, const CholeskyOptions& options) {
  return absl::make_unique<HloCholeskyInstruction>(shape, a, options);
}
1025
// Creates a dot (matrix/tensor contraction) of `lhs` and `rhs` using the
// given contraction dimension numbers and precision configuration.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config) {
  return absl::make_unique<HloDotInstruction>(
      shape, lhs, rhs, dimension_numbers, precision_config);
}
1033
// Creates a reduce-precision instruction that rounds `operand` to the given
// number of exponent and mantissa bits.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateReducePrecision(const Shape& shape,
                                      HloInstruction* operand,
                                      const int exponent_bits,
                                      const int mantissa_bits) {
  return absl::make_unique<HloReducePrecisionInstruction>(
      shape, operand, exponent_bits, mantissa_bits);
}
1042
// Creates an all-gather collective concatenating `operand` across replicas
// along `all_gather_dimension`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllGather(
    const Shape& shape, HloInstruction* operand, int64 all_gather_dimension,
    const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
    const absl::optional<int64>& channel_id, bool use_global_device_ids) {
  return absl::make_unique<HloAllGatherInstruction>(
      shape, operand, all_gather_dimension, replica_groups, constrain_layout,
      channel_id, use_global_device_ids);
}
1051
// Creates an all-reduce collective combining the operands across replicas
// with `reduce_computation`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllReduce(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloComputation* reduce_computation,
    const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
    const absl::optional<int64>& channel_id, bool use_global_device_ids) {
  return absl::make_unique<HloAllReduceInstruction>(
      shape, operands, reduce_computation, replica_groups, constrain_layout,
      channel_id, use_global_device_ids);
}
1061
// Creates an all-to-all collective exchanging data among replicas, optionally
// splitting along `split_dimension`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
    const absl::optional<int64>& channel_id,
    const absl::optional<int64>& split_dimension) {
  return absl::make_unique<HloAllToAllInstruction>(
      shape, operands, replica_groups, constrain_layout, channel_id,
      split_dimension);
}
1071
// Creates a collective-permute sending data between the given
// (source, target) replica pairs.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCollectivePermute(
    const Shape& shape, HloInstruction* operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs,
    const absl::optional<int64>& channel_id) {
  return absl::make_unique<HloCollectivePermuteInstruction>(
      HloOpcode::kCollectivePermute, shape, operand, source_target_pairs,
      channel_id);
}
1081
// Creates the asynchronous start form of collective-permute; pairs with
// kCollectivePermuteDone.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCollectivePermuteStart(
    const Shape& shape, HloInstruction* operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs,
    const absl::optional<int64>& channel_id) {
  return absl::make_unique<HloCollectivePermuteInstruction>(
      HloOpcode::kCollectivePermuteStart, shape, operand, source_target_pairs,
      channel_id);
}
1091
// Creates a replica-id instruction; the shape must be u32[] (layout ignored).
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId(
    const Shape& shape) {
  CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
      << "HloInstruction replica-id must have a shape of u32[], but "
      << shape.ToString() << " is specified";
  return absl::WrapUnique(new HloInstruction(HloOpcode::kReplicaId, shape));
}
1099
// Creates a partition-id instruction; the shape must be u32[] (layout
// ignored).
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePartitionId(
    const Shape& shape) {
  CHECK(Shape::Equal().IgnoreLayout()(shape, ShapeUtil::MakeShape(U32, {})))
      << "HloInstruction partition-id must have a shape of u32[], but "
      << shape.ToString() << " is specified";
  return absl::WrapUnique(new HloInstruction(HloOpcode::kPartitionId, shape));
}
1107
// Creates an infeed instruction reading `infeed_shape` from the host, ordered
// after `token_operand`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
    const Shape& infeed_shape, HloInstruction* token_operand,
    const string& config) {
  return absl::make_unique<HloInfeedInstruction>(infeed_shape, token_operand,
                                                 config);
}
1114
// Creates an outfeed instruction writing `operand` to the host, ordered after
// `token_operand`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
    const Shape& outfeed_shape, HloInstruction* operand,
    HloInstruction* token_operand, absl::string_view outfeed_config) {
  return absl::make_unique<HloOutfeedInstruction>(
      outfeed_shape, operand, token_operand, outfeed_config);
}
1121
// Creates a send instruction transmitting `operand` over `channel_id`;
// pairs with CreateSendDone.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
    HloInstruction* operand, HloInstruction* token, int64 channel_id,
    bool is_host_transfer) {
  return absl::make_unique<HloSendInstruction>(operand, token, channel_id,
                                               is_host_transfer);
}
1128
// Creates a send-done completing a previously-created send. The operand must
// be the matching kSend instruction (checked at runtime).
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
    HloInstruction* operand, bool is_host_transfer) {
  auto send_operand = DynCast<HloSendInstruction>(operand);
  CHECK(send_operand != nullptr)
      << "SendDone must take the context operand from Send";
  return absl::make_unique<HloSendDoneInstruction>(send_operand,
                                                   is_host_transfer);
}
1137
// Creates a recv instruction receiving `shape` over `channel_id`; pairs with
// CreateRecvDone.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
    const Shape& shape, HloInstruction* token, int64 channel_id,
    bool is_host_transfer) {
  return absl::make_unique<HloRecvInstruction>(shape, token, channel_id,
                                               is_host_transfer);
}
1144
// Creates a recv-done completing a previously-created recv. The operand must
// be the matching kRecv instruction (checked at runtime).
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
    HloInstruction* operand, bool is_host_transfer) {
  auto recv_operand = DynCast<HloRecvInstruction>(operand);
  CHECK(recv_operand != nullptr)
      << "RecvDone must take the context operand from Recv";
  return absl::make_unique<HloRecvDoneInstruction>(recv_operand,
                                                   is_host_transfer);
}
1153
// Creates a reverse instruction flipping `operand` along `dimensions`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64> dimensions) {
  return absl::make_unique<HloReverseInstruction>(shape, operand, dimensions);
}
1159
CreateAfterAll(absl::Span<HloInstruction * const> operands)1160 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
1161 absl::Span<HloInstruction* const> operands) {
1162 CHECK(!operands.empty());
1163 auto instruction = absl::WrapUnique(
1164 new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
1165 for (auto operand : operands) {
1166 instruction->AppendOperand(operand);
1167 }
1168 return instruction;
1169 }
1170
// Creates a fresh token: an operand-less kAfterAll with token shape.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
  return absl::WrapUnique(
      new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
}
1175
// Creates an add-dependency instruction: forwards `data_operand`'s value
// (same shape) while ordering it after `token_operand`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateAddDependency(HloInstruction* data_operand,
                                    HloInstruction* token_operand) {
  auto instruction = absl::WrapUnique(
      new HloInstruction(HloOpcode::kAddDependency, data_operand->shape()));
  instruction->AppendOperand(data_operand);
  instruction->AppendOperand(token_operand);
  return instruction;
}
1185
// Creates a while instruction looping `body` over state seeded by `init`
// while `condition` holds.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
    const Shape& shape, HloComputation* condition, HloComputation* body,
    HloInstruction* init) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
  instruction->AppendOperand(init);
  // Body comes before condition computation in the vector.
  instruction->called_computations_.push_back(body);
  instruction->called_computations_.push_back(condition);
  return instruction;
}
1197
// Creates a two-branch (predicated) conditional: runs `true_computation` on
// `true_computation_arg` when `pred` is true, else `false_computation` on
// `false_computation_arg`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
    const Shape& shape, HloInstruction* pred,
    HloInstruction* true_computation_arg, HloComputation* true_computation,
    HloInstruction* false_computation_arg, HloComputation* false_computation) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
  instruction->AppendOperand(pred);
  instruction->AppendOperand(true_computation_arg);
  instruction->AppendOperand(false_computation_arg);
  // In called_computations_, the index of true_computation must be 0 and that
  // of false computation must be 1, as defined by kTrueComputationIndex and
  // kFalseComputationIndex.
  instruction->called_computations_.push_back(true_computation);
  instruction->called_computations_.push_back(false_computation);
  return instruction;
}
1214
CreateConditional(const Shape & shape,HloInstruction * branch_index,absl::Span<HloComputation * const> branch_computations,absl::Span<HloInstruction * const> branch_computation_args)1215 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
1216 const Shape& shape, HloInstruction* branch_index,
1217 absl::Span<HloComputation* const> branch_computations,
1218 absl::Span<HloInstruction* const> branch_computation_args) {
1219 auto instruction =
1220 absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
1221 instruction->AppendOperand(branch_index);
1222 CHECK_EQ(branch_computations.size(), branch_computation_args.size());
1223 for (int i = 0; i < branch_computations.size(); ++i) {
1224 instruction->called_computations_.push_back(branch_computations[i]);
1225 instruction->AppendOperand(branch_computation_args[i]);
1226 }
1227 return instruction;
1228 }
1229
// Creates a slice of `operand` over [start_indices, limit_indices) with the
// given per-dimension strides.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64> start_indices,
    absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
  return absl::make_unique<HloSliceInstruction>(shape, operand, start_indices,
                                                limit_indices, strides);
}
1237
// Creates a dynamic-slice extracting `slice_sizes` from `operand` at runtime
// offsets given by the scalar `start_indices` operands.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
    const Shape& shape, HloInstruction* operand,
    absl::Span<HloInstruction* const> start_indices,
    absl::Span<const int64> slice_sizes) {
  return absl::make_unique<HloDynamicSliceInstruction>(
      shape, operand, start_indices, slice_sizes);
}
1245
// Creates a dynamic-update-slice writing `update` into `operand` at runtime
// offsets given by the scalar `start_indices` operands.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateDynamicUpdateSlice(
    const Shape& shape, HloInstruction* operand, HloInstruction* update,
    absl::Span<HloInstruction* const> start_indices) {
  return absl::make_unique<HloDynamicUpdateSliceInstruction>(
      shape, operand, update, start_indices);
}
1253
// Creates a concatenate joining the operands along `dimension`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    int64 dimension) {
  return absl::make_unique<HloConcatenateInstruction>(shape, operands,
                                                      dimension);
}
1260
// Creates a convert instruction casting `operand` to `shape`'s element type.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
    const Shape& shape, HloInstruction* operand) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
  instruction->AppendOperand(operand);
  return instruction;
}
1268
// Creates a bitcast-convert reinterpreting `operand`'s bits as `shape`'s
// element type (no value conversion).
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBitcastConvert(const Shape& shape,
                                     HloInstruction* operand) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
  instruction->AppendOperand(operand);
  return instruction;
}
1277
// Creates a bitcast reinterpreting `operand` as `shape` without moving data.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBitcast(
    const Shape& shape, HloInstruction* operand) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kBitcast, shape));
  instruction->AppendOperand(operand);
  return instruction;
}
1285
CreateReduce(const Shape & shape,HloInstruction * operand,HloInstruction * init_value,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)1286 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1287 const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
1288 absl::Span<const int64> dimensions_to_reduce,
1289 HloComputation* reduce_computation) {
1290 auto instruction = absl::WrapUnique(new HloReduceInstruction(
1291 shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
1292 return std::move(instruction);
1293 }
1294
CreateReduce(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloInstruction * const> init_values,absl::Span<const int64> dimensions_to_reduce,HloComputation * reduce_computation)1295 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
1296 const Shape& shape, absl::Span<HloInstruction* const> operands,
1297 absl::Span<HloInstruction* const> init_values,
1298 absl::Span<const int64> dimensions_to_reduce,
1299 HloComputation* reduce_computation) {
1300 std::vector<HloInstruction*> all_args;
1301 all_args.reserve(operands.size() * 2);
1302 all_args.insert(all_args.end(), operands.begin(), operands.end());
1303 all_args.insert(all_args.end(), init_values.begin(), init_values.end());
1304 return absl::make_unique<HloReduceInstruction>(
1305 shape, all_args, dimensions_to_reduce, reduce_computation);
1306 }
1307
// Creates a reduce-window applying `reduce_computation` over `window` of
// `operand`, seeded with `init_value`.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
    const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    const Window& window, HloComputation* reduce_computation) {
  return absl::make_unique<HloReduceWindowInstruction>(
      shape, operand, init_value, window, reduce_computation);
}
1314
// Variadic form of reduce-window: operand i is paired with init value i.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    absl::Span<HloInstruction* const> init_values, const Window& window,
    HloComputation* reduce_computation) {
  return absl::make_unique<HloReduceWindowInstruction>(
      shape, operands, init_values, window, reduce_computation);
}
// Creates a batch-norm-training instruction normalizing `operand` with
// `scale`/`offset` along `feature_index`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormTraining(const Shape& shape,
                                        HloInstruction* operand,
                                        HloInstruction* scale,
                                        HloInstruction* offset, float epsilon,
                                        int64 feature_index) {
  return absl::make_unique<HloBatchNormTrainingInstruction>(
      shape, operand, scale, offset, epsilon, feature_index);
}
1331
// Creates a batch-norm-inference instruction normalizing `operand` with
// precomputed `mean`/`variance`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormInference(
    const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
    float epsilon, int64 feature_index) {
  return absl::make_unique<HloBatchNormInferenceInstruction>(
      shape, operand, scale, offset, mean, variance, epsilon, feature_index);
}
1340
// Creates a batch-norm-grad instruction computing gradients of batch norm
// w.r.t. input, scale and offset.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
                                    HloInstruction* scale, HloInstruction* mean,
                                    HloInstruction* variance,
                                    HloInstruction* grad_output, float epsilon,
                                    int64 feature_index) {
  return absl::make_unique<HloBatchNormGradInstruction>(
      shape, operand, scale, mean, variance, grad_output, epsilon,
      feature_index);
}
1351
// Creates a select-and-scatter instruction: windows of `operand` chosen by
// `select` receive scattered values from `source`, combined via `scatter`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateSelectAndScatter(
    const Shape& shape, HloInstruction* operand, HloComputation* select,
    const Window& window, HloInstruction* source, HloInstruction* init_value,
    HloComputation* scatter) {
  return absl::make_unique<HloSelectAndScatterInstruction>(
      shape, operand, select, window, source, init_value, scatter);
}
1360
// Creates a broadcast expanding `operand` into `shape`; `broadcast_dimensions`
// maps each operand dimension to an output dimension.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64> broadcast_dimensions) {
  return absl::make_unique<HloBroadcastInstruction>(shape, operand,
                                                    broadcast_dimensions);
}
1367
// Creates a get-dimension-size instruction returning the (possibly dynamic)
// size of `operand`'s dimension `dimension`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetDimensionSize(const Shape& shape,
                                       HloInstruction* operand,
                                       int64 dimension) {
  return absl::make_unique<HloGetDimensionSizeInstruction>(shape, operand,
                                                           dimension);
}
1375
// Creates a set-dimension-size instruction setting the dynamic size of
// `operand`'s dimension `dimension` to `val`.
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateSetDimensionSize(const Shape& shape,
                                       HloInstruction* operand,
                                       HloInstruction* val, int64 dimension) {
  return absl::make_unique<HloSetDimensionSizeInstruction>(shape, operand, val,
                                                           dimension);
}
1383
1384 /* static */ std::unique_ptr<HloInstruction>
CreateBroadcastSequence(const Shape & output_shape,HloInstruction * operand,const std::function<HloInstruction * (std::unique_ptr<HloInstruction>)> & adder)1385 HloInstruction::CreateBroadcastSequence(
1386 const Shape& output_shape, HloInstruction* operand,
1387 const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
1388 adder) {
1389 CHECK(ShapeUtil::IsScalar(operand->shape()) ||
1390 operand->shape().rank() == output_shape.rank());
1391 Shape broadcast_shape = ShapeUtil::ChangeElementType(
1392 output_shape, operand->shape().element_type());
1393 // Do explicit broadcast for scalar.
1394 if (ShapeUtil::IsScalar(operand->shape())) {
1395 auto broadcast =
1396 HloInstruction::CreateBroadcast(broadcast_shape, operand, {});
1397 broadcast->set_metadata(operand->metadata());
1398 if (operand->has_sharding()) {
1399 broadcast->set_sharding(operand->sharding());
1400 }
1401 broadcast->set_frontend_attributes(operand->frontend_attributes());
1402 return broadcast;
1403 }
1404 // Do explicit broadcast for degenerate broadcast.
1405 std::vector<int64> broadcast_dimensions;
1406 std::vector<int64> reshaped_dimensions;
1407 for (int i = 0; i < operand->shape().rank(); i++) {
1408 if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
1409 broadcast_dimensions.push_back(i);
1410 reshaped_dimensions.push_back(operand->shape().dimensions(i));
1411 } else {
1412 CHECK_EQ(operand->shape().dimensions(i), 1)
1413 << "An explicit broadcast sequence requires the broadcasted "
1414 "dimensions to be trivial; operand: "
1415 << operand->ToString() << "; output_shape: " << output_shape;
1416 }
1417 }
1418 // Eliminate the size one dimensions.
1419 HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape(
1420 ShapeUtil::MakeShape(operand->shape().element_type(),
1421 reshaped_dimensions),
1422 operand));
1423 reshaped_operand->set_metadata(operand->metadata());
1424 if (operand->has_sharding()) {
1425 reshaped_operand->set_sharding(operand->sharding());
1426 }
1427 reshaped_operand->set_frontend_attributes(operand->frontend_attributes());
1428 // Broadcast 'reshape' up to the larger size.
1429 auto broadcast = HloInstruction::CreateBroadcast(
1430 broadcast_shape, reshaped_operand, broadcast_dimensions);
1431 broadcast->set_metadata(operand->metadata());
1432 if (operand->has_sharding()) {
1433 broadcast->set_sharding(operand->sharding());
1434 }
1435 broadcast->set_frontend_attributes(operand->frontend_attributes());
1436 return broadcast;
1437 }
1438
CreatePad(const Shape & shape,HloInstruction * operand,HloInstruction * padding_value,const PaddingConfig & padding_config)1439 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
1440 const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
1441 const PaddingConfig& padding_config) {
1442 return absl::make_unique<HloPadInstruction>(shape, operand, padding_value,
1443 padding_config);
1444 }
1445
CreateReshape(const Shape & shape,HloInstruction * operand,int64 inferred_dimension)1446 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
1447 const Shape& shape, HloInstruction* operand, int64 inferred_dimension) {
1448 CHECK_EQ(ShapeUtil::ElementsIn(shape),
1449 ShapeUtil::ElementsIn(operand->shape()))
1450 << "shape: " << ShapeUtil::HumanString(shape)
1451 << " operand: " << ShapeUtil::HumanString(operand->shape());
1452
1453 return absl::make_unique<HloReshapeInstruction>(shape, operand,
1454 inferred_dimension);
1455 }
1456
1457 /* static */ std::unique_ptr<HloInstruction>
CreateDynamicReshape(const Shape & shape,HloInstruction * data_operand,absl::Span<HloInstruction * const> dim_sizes)1458 HloInstruction::CreateDynamicReshape(
1459 const Shape& shape, HloInstruction* data_operand,
1460 absl::Span<HloInstruction* const> dim_sizes) {
1461 CHECK_EQ(ShapeUtil::ElementsIn(shape),
1462 ShapeUtil::ElementsIn(data_operand[0].shape()))
1463 << "shape: " << ShapeUtil::HumanString(shape)
1464 << " operand: " << ShapeUtil::HumanString(data_operand[0].shape());
1465 CHECK_EQ(shape.rank(), dim_sizes.size());
1466 return absl::make_unique<HloDynamicReshapeInstruction>(shape, data_operand,
1467 dim_sizes);
1468 }
1469
CreateTranspose(const Shape & shape,HloInstruction * operand,absl::Span<const int64> dimensions)1470 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
1471 const Shape& shape, HloInstruction* operand,
1472 absl::Span<const int64> dimensions) {
1473 return absl::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
1474 }
1475
CreateSort(const Shape & shape,int64 dimension,absl::Span<HloInstruction * const> operands,HloComputation * compare,bool is_stable)1476 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
1477 const Shape& shape, int64 dimension,
1478 absl::Span<HloInstruction* const> operands, HloComputation* compare,
1479 bool is_stable) {
1480 return absl::make_unique<HloSortInstruction>(shape, dimension, operands,
1481 compare, is_stable);
1482 }
1483
CreateFusion(const Shape & shape,FusionKind fusion_kind,HloInstruction * fused_root)1484 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1485 const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
1486 return absl::make_unique<HloFusionInstruction>(shape, fusion_kind,
1487 fused_root);
1488 }
1489
CreateFusion(const Shape & shape,FusionKind fusion_kind,absl::Span<HloInstruction * const> operands,HloComputation * fusion_computation)1490 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
1491 const Shape& shape, FusionKind fusion_kind,
1492 absl::Span<HloInstruction* const> operands,
1493 HloComputation* fusion_computation) {
1494 return absl::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
1495 fusion_computation);
1496 }
1497
// Sets this instruction's sharding from a non-tuple `sharding`. If this
// instruction is tuple-shaped, the single sharding is expanded into a tuple
// sharding that applies it to every leaf of the tuple shape tree.
void HloInstruction::set_single_sharding(const HloSharding& sharding) {
  CHECK(!sharding.IsTuple()) << sharding;
  if (shape().IsTuple()) {
    // Replicate the single sharding across the whole tuple shape tree.
    set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape())));
  } else {
    set_sharding(sharding);
  }
}
1506
// Propagates sharding (when the shapes are compatible), metadata and frontend
// attributes from this instruction onto `derived_instruction`, which is
// expected to have been derived (e.g. cloned) from this one.
void HloInstruction::SetupDerivedInstruction(
    HloInstruction* derived_instruction) const {
  if (sharding_ != nullptr &&
      ShapeUtil::CompatibleKind(shape_, derived_instruction->shape())) {
    // Only copy sharding if the tuple tree shape of the two instruction is
    // compatible because copying it between differently shaped instructions
    // can produce invalid shardings.
    derived_instruction->set_sharding(*sharding_);
  } else {
    // Incompatible (or absent) sharding: make sure no stale sharding remains
    // on the derived instruction.
    derived_instruction->clear_sharding();
  }
  derived_instruction->set_metadata(metadata_);
  derived_instruction->set_frontend_attributes(frontend_attributes_);
}
1521
// Returns true if this instruction itself has a side effect, without looking
// into any computations it calls (see HasSideEffect for the recursive check).
bool HloInstruction::HasSideEffectNoRecurse() const {
  switch (opcode_) {
    // These opcodes always have externally visible effects.
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kRng:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kTrace:
      return true;
    // Cross-replica/partition collectives count as side-effecting when they
    // carry a channel id or constrain layouts.
    case HloOpcode::kAllReduce:
      return channel_id().has_value() ||
             Cast<HloAllReduceInstruction>(this)->constrain_layout();
    case HloOpcode::kAllToAll:
      return Cast<HloAllToAllInstruction>(this)->constrain_layout();
    // Custom calls declare their own side-effect behavior.
    case HloOpcode::kCustomCall:
      return Cast<HloCustomCallInstruction>(this)
          ->custom_call_has_side_effect();
    default:
      return false;
  }
}
1546
HasSideEffect() const1547 bool HloInstruction::HasSideEffect() const {
1548 if (HasSideEffectNoRecurse()) {
1549 return true;
1550 }
1551 // Check if any of the called computations has a side effect.
1552 for (const auto& computation : called_computations()) {
1553 if (computation->HasSideEffect()) {
1554 return true;
1555 }
1556 }
1557 return false;
1558 }
1559
CreateCall(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * computation)1560 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
1561 const Shape& shape, absl::Span<HloInstruction* const> operands,
1562 HloComputation* computation) {
1563 std::unique_ptr<HloInstruction> instruction =
1564 absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
1565 for (auto operand : operands) {
1566 instruction->AppendOperand(operand);
1567 }
1568 instruction->called_computations_.push_back(computation);
1569 return instruction;
1570 }
1571
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,string opaque)1572 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1573 const Shape& shape, absl::Span<HloInstruction* const> operands,
1574 absl::string_view custom_call_target, string opaque) {
1575 return absl::make_unique<HloCustomCallInstruction>(
1576 shape, operands, custom_call_target, std::move(opaque));
1577 }
1578
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,HloComputation * to_apply,absl::string_view custom_call_target,string opaque)1579 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1580 const Shape& shape, absl::Span<HloInstruction* const> operands,
1581 HloComputation* to_apply, absl::string_view custom_call_target,
1582 string opaque) {
1583 return absl::make_unique<HloCustomCallInstruction>(
1584 shape, operands, to_apply, custom_call_target, std::move(opaque));
1585 }
1586
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::Span<HloComputation * const> called_computations,absl::string_view custom_call_target,string opaque)1587 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1588 const Shape& shape, absl::Span<HloInstruction* const> operands,
1589 absl::Span<HloComputation* const> called_computations,
1590 absl::string_view custom_call_target, string opaque) {
1591 return absl::make_unique<HloCustomCallInstruction>(
1592 shape, operands, called_computations, custom_call_target,
1593 std::move(opaque));
1594 }
1595
CreateCustomCall(const Shape & shape,absl::Span<HloInstruction * const> operands,absl::string_view custom_call_target,absl::Span<const Shape> operand_shapes_with_layout,string opaque)1596 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1597 const Shape& shape, absl::Span<HloInstruction* const> operands,
1598 absl::string_view custom_call_target,
1599 absl::Span<const Shape> operand_shapes_with_layout, string opaque) {
1600 return absl::make_unique<HloCustomCallInstruction>(
1601 shape, operands, custom_call_target, std::move(opaque),
1602 operand_shapes_with_layout);
1603 }
1604
CreateTuple(absl::Span<HloInstruction * const> elements)1605 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
1606 absl::Span<HloInstruction* const> elements) {
1607 std::vector<Shape> element_shapes;
1608 for (auto element : elements) {
1609 element_shapes.push_back(element->shape());
1610 }
1611 Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);
1612 return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
1613 }
1614
CreateGather(const Shape & shape,HloInstruction * operand,HloInstruction * start_indices,const GatherDimensionNumbers & gather_dim_numbers,absl::Span<const int64> slice_sizes,bool indices_are_sorted)1615 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
1616 const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
1617 const GatherDimensionNumbers& gather_dim_numbers,
1618 absl::Span<const int64> slice_sizes, bool indices_are_sorted) {
1619 return absl::make_unique<HloGatherInstruction>(
1620 shape, operand, start_indices, gather_dim_numbers, slice_sizes,
1621 indices_are_sorted);
1622 }
1623
CreateScatter(const Shape & shape,HloInstruction * operand,HloInstruction * scatter_indices,HloInstruction * updates,HloComputation * update_computation,const ScatterDimensionNumbers & scatter_dim_numbers,bool indices_are_sorted,bool unique_indices)1624 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
1625 const Shape& shape, HloInstruction* operand,
1626 HloInstruction* scatter_indices, HloInstruction* updates,
1627 HloComputation* update_computation,
1628 const ScatterDimensionNumbers& scatter_dim_numbers, bool indices_are_sorted,
1629 bool unique_indices) {
1630 return absl::make_unique<HloScatterInstruction>(
1631 shape, operand, scatter_indices, updates, update_computation,
1632 scatter_dim_numbers, indices_are_sorted, unique_indices);
1633 }
1634
CreateDomain(const Shape & shape,HloInstruction * operand,std::unique_ptr<DomainMetadata> operand_side_metadata,std::unique_ptr<DomainMetadata> user_side_metadata)1635 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
1636 const Shape& shape, HloInstruction* operand,
1637 std::unique_ptr<DomainMetadata> operand_side_metadata,
1638 std::unique_ptr<DomainMetadata> user_side_metadata) {
1639 return absl::make_unique<HloDomainInstruction>(
1640 shape, operand, std::move(operand_side_metadata),
1641 std::move(user_side_metadata));
1642 }
1643
// Clones this instruction with a (possibly different) result `shape` and the
// given replacement operands, dispatching on the opcode to the appropriate
// factory. If `context` is non-null, the mapping from this instruction to the
// clone is recorded there and called computations are deep-cloned into the
// context's module when they belong to a different module.
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
    const Shape& shape, absl::Span<HloInstruction* const> new_operands,
    HloCloneContext* context) const {
  VLOG(3) << "CloneWithNewOperands:\n  " << ToString();
  VLOG(3) << "  new operands:";
  for (const HloInstruction* new_operand : new_operands) {
    VLOG(3) << "    %" << new_operand->name();
  }

  std::unique_ptr<HloInstruction> clone;
  // Explicitly call the factory for the instruction type. This is more robust
  // in the face of code changes than copying fields explicitly. This also
  // properly sets the user fields of the operands.
  switch (opcode_) {
    // Ops migrated to subclasses.
    // TODO(b/80131774): Remove this switch when migration is complete.
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kFft:
    case HloOpcode::kCompare:
    case HloOpcode::kCopyStart:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReverse:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReduce:
    case HloOpcode::kTranspose:
    case HloOpcode::kBroadcast:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kMap:
    case HloOpcode::kSlice:
    case HloOpcode::kConstant:
    case HloOpcode::kTrace:
    case HloOpcode::kFusion:
    case HloOpcode::kRng:
    case HloOpcode::kRngBitGenerator:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kParameter:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kConvolution:
    case HloOpcode::kCustomCall:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kPad:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kSort:
    case HloOpcode::kGather:
    case HloOpcode::kScatter:
    case HloOpcode::kIota:
    case HloOpcode::kDot:
    case HloOpcode::kDomain:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
      // Subclassed instructions know how to clone themselves.
      clone = CloneWithNewOperandsImpl(shape, new_operands, context);
      break;
    // Unary ops.
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kBitcast:
    case HloOpcode::kCeil:
    case HloOpcode::kClz:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCos:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kFloor:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kLogistic:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kTanh:
      CHECK_EQ(new_operands.size(), 1);
      clone = CreateUnary(shape, opcode_, new_operands[0]);
      break;
    // Binary ops.
    case HloOpcode::kAdd:
    case HloOpcode::kAtan2:
    case HloOpcode::kComplex:
    case HloOpcode::kDivide:
    case HloOpcode::kMultiply:
    case HloOpcode::kSubtract:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      CHECK_EQ(new_operands.size(), 2);
      clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
      break;
    // Ternary ops.
    case HloOpcode::kClamp:
    case HloOpcode::kSelect:
    case HloOpcode::kTupleSelect:
      CHECK_EQ(new_operands.size(), 3);
      clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
                            new_operands[2]);
      break;
    // Other supported ops.
    case HloOpcode::kCall:
      clone = CreateCall(shape, new_operands, to_apply());
      break;
    case HloOpcode::kConvert:
      CHECK_EQ(new_operands.size(), 1);
      clone = CreateConvert(shape, new_operands[0]);
      break;
    case HloOpcode::kBitcastConvert:
      CHECK_EQ(new_operands.size(), 1);
      clone = CreateBitcastConvert(shape, new_operands[0]);
      break;
    case HloOpcode::kDynamicUpdateSlice:
      // Operands: base, update, then one start index per dimension.
      clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
                                       new_operands.subspan(2));
      break;
    case HloOpcode::kTuple:
      clone = CreateTuple(new_operands);
      // CreateTuple derives the shape from the operands; override it with the
      // requested `shape` (which may e.g. differ in layout).
      *clone->mutable_shape() = shape;
      break;
    case HloOpcode::kWhile:
      CHECK_EQ(new_operands.size(), 1);
      clone =
          CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
      break;
    case HloOpcode::kConditional:
      // Operands: branch index/predicate followed by one operand per branch.
      CHECK_EQ(new_operands.size(), branch_count() + 1);
      clone = CreateConditional(shape, new_operands[0],
                                absl::MakeSpan(branch_computations()),
                                new_operands.subspan(1));
      break;
    case HloOpcode::kAfterAll:
      // An after-all with no operands is just a token.
      if (new_operands.empty()) {
        clone = CreateToken();
      } else {
        clone = CreateAfterAll(new_operands);
      }
      break;
    case HloOpcode::kAddDependency:
      CHECK_EQ(new_operands.size(), 2);
      clone = CreateAddDependency(new_operands[0], new_operands[1]);
      break;
    case HloOpcode::kReplicaId:
      CHECK_EQ(new_operands.size(), 0);
      clone = CreateReplicaId(shape);
      break;
    case HloOpcode::kPartitionId:
      CHECK_EQ(new_operands.size(), 0);
      clone = CreatePartitionId(shape);
      break;
  }
  // SetupDerivedInstruction will setup the precision_config_ field.
  SetupDerivedInstruction(clone.get());
  clone->set_parent(parent_);
  clone->set_outer_dimension_partitions(outer_dimension_partitions_);
  clone->set_raw_backend_config_string(backend_config_);
  if (context != nullptr) {
    context->MapInstruction(this, clone.get());
    // Called computations living in a different module must be cloned into
    // the context's module; same-module computations are shared as-is.
    clone->ReplaceCalledComputations([&](HloComputation* callee) {
      return callee->parent() != context->module()
                 ? context->module()->DeepCloneComputation(callee, context)
                 : callee;
    });
  }
  return clone;
}
1838
// Severs all def-use links between this instruction and the rest of the
// graph: removes this instruction from every operand's user set and nulls out
// every user's operand slot that pointed at this instruction. Idempotent via
// the cleaned_up_ flag. Typically called as part of instruction deletion.
void HloInstruction::DetachFromOperandsAndUsers() {
  if (cleaned_up_) {
    return;
  }
  cleaned_up_ = true;
  // Detach from operands. An instruction may be repeated as an operand. To
  // avoid calling RemoveUser twice on the same operand, check before remove.
  for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
    HloInstruction* operand = operands_[operand_num];
    if (operand == nullptr) {
      // Slot was already detached (e.g. by a user's earlier cleanup).
      continue;
    }
    if (operand->user_map_.find(this) != operand->user_map_.end()) {
      operand->RemoveUser(this);
    }
    operands_[operand_num] = nullptr;
  }

  // Update users. Set `nullptr` to the corresponding operand slot for users.
  for (auto& user : this->users()) {
    for (int i = 0; i < user->operand_count(); ++i) {
      if (user->operands_[i] == this) {
        user->operands_[i] = nullptr;
      }
    }
  }
}
1866
// Clones this instruction (reusing the current operands) with a new result
// `shape`, giving the clone a name derived from this instruction's name and
// `suffix`. An empty suffix keeps the original name.
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewShape(
    const Shape& shape, const string& suffix, HloCloneContext* context) const {
  std::unique_ptr<HloInstruction> clone =
      CloneWithNewOperands(shape, operands_, context);
  if (suffix.empty()) {
    clone->name_ = name();
  } else {
    // If an instruction is cloned multiple times avoid names like
    // foo.suffix.suffix.suffix. Instead of repeating the suffix add a numeric
    // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the
    // clone of foo.suffix2 is named foo.suffix3 and so on.
    const string dot_suffix = "." + suffix;
    size_t index = name().rfind(dot_suffix);
    if (index == string::npos) {
      // Existing name does not include ".suffix".
      clone->name_ = name() + dot_suffix;
    } else {
      // Existing name includes ".suffix". Determine if substring after
      // ".suffix" is numeric and should be replaced with an incremented number.
      string after_suffix = name().substr(index + dot_suffix.size());
      if (after_suffix.empty()) {
        // Existing name ends in ".suffix". New name should end in ".suffix2".
        clone->name_ = name() + "2";
      } else {
        // If names ends with .suffix[0-9]+ then replace with a suffix with the
        // numeric value incremented.
        int64 numeric_suffix;
        if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
          clone->name_ =
              StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
        } else {
          // Substring after ".suffix" is non-numeric.
          clone->name_ = name() + dot_suffix;
        }
      }
    }
  }
  return clone;
}
1906
Clone(const string & suffix,HloCloneContext * context) const1907 std::unique_ptr<HloInstruction> HloInstruction::Clone(
1908 const string& suffix, HloCloneContext* context) const {
1909 std::unique_ptr<HloInstruction> clone =
1910 CloneWithNewShape(shape_, suffix, context);
1911 return clone;
1912 }
1913
1914 std::pair<const HloInstruction*, ShapeIndex>
LatestNonGteAncestorAndIndex() const1915 HloInstruction::LatestNonGteAncestorAndIndex() const {
1916 const HloInstruction* hlo = this;
1917 ShapeIndex index;
1918 while (hlo->opcode() == HloOpcode::kGetTupleElement) {
1919 index.push_back(hlo->tuple_index());
1920 hlo = hlo->operand(0);
1921 }
1922
1923 // We built up index in the reverse order from what we want.
1924 std::reverse(index.begin(), index.end());
1925
1926 return {hlo, index};
1927 }
1928
LatestNonGteAncestor() const1929 const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
1930 const HloInstruction* hlo = this;
1931 while (hlo->opcode() == HloOpcode::kGetTupleElement) {
1932 hlo = hlo->operand(0);
1933 }
1934 return hlo;
1935 }
1936
// Returns the i-th operand. Access is bounds-checked via vector::at.
const HloInstruction* HloInstruction::operand(int64 i) const {
  return operands_.at(i);
}
1940
mutable_operand(int64 i)1941 HloInstruction* HloInstruction::mutable_operand(int64 i) {
1942 CHECK(operands_[i] != nullptr);
1943 return operands_.at(i);
1944 }
1945
operand_index(const HloInstruction * target) const1946 int64 HloInstruction::operand_index(const HloInstruction* target) const {
1947 for (int64 i = 0; i < operand_count(); ++i) {
1948 if (target == operand(i)) {
1949 return i;
1950 }
1951 }
1952 LOG(FATAL) << "target was not an operand: " << target->ToString();
1953 }
1954
unique_operands() const1955 HloInstruction::InstructionVector HloInstruction::unique_operands() const {
1956 InstructionVector unique;
1957 absl::flat_hash_set<const HloInstruction*> seen;
1958 for (HloInstruction* operand : operands()) {
1959 if (seen.insert(operand).second) {
1960 unique.push_back(operand);
1961 }
1962 }
1963 return unique;
1964 }
1965
AddControlDependencyTo(HloInstruction * instruction)1966 Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
1967 TF_RET_CHECK(instruction->parent() == parent());
1968 if (!absl::c_linear_search(control_successors_, instruction)) {
1969 control_successors_.push_back(instruction);
1970 TF_RET_CHECK(
1971 !absl::c_linear_search(instruction->control_predecessors_, this));
1972 instruction->control_predecessors_.push_back(this);
1973 }
1974 return Status::OK();
1975 }
1976
// Removes the control edge from this instruction to `instruction`. Both side
// lists are updated; it is an error if the edge does not exist.
Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
  TF_RET_CHECK(instruction->parent() == parent());
  TF_RETURN_IF_ERROR(EraseElementFromVector(&control_successors_, instruction));
  TF_RETURN_IF_ERROR(
      EraseElementFromVector(&instruction->control_predecessors_, this));
  return Status::OK();
}
1984
DropAllControlDeps()1985 Status HloInstruction::DropAllControlDeps() {
1986 for (auto* ctrl_succ : control_successors_) {
1987 TF_RETURN_IF_ERROR(
1988 EraseElementFromVector(&ctrl_succ->control_predecessors_, this));
1989 }
1990 for (auto* ctrl_pred : control_predecessors_) {
1991 TF_RETURN_IF_ERROR(
1992 EraseElementFromVector(&ctrl_pred->control_successors_, this));
1993 }
1994 control_successors_.clear();
1995 control_predecessors_.clear();
1996 return Status::OK();
1997 }
1998
CopyAllControlDepsFrom(const HloInstruction * inst)1999 Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) {
2000 for (auto* ctrl_pred : inst->control_predecessors()) {
2001 TF_RETURN_IF_ERROR(ctrl_pred->AddControlDependencyTo(this));
2002 }
2003
2004 for (auto* ctrl_succ : inst->control_successors()) {
2005 TF_RETURN_IF_ERROR(this->AddControlDependencyTo(ctrl_succ));
2006 }
2007
2008 return Status::OK();
2009 }
2010
// Core of the Identical/IdenticalIgnoringChannelIdValues checks: compares
// opcode, shape (exactly or up to layout depending on `layout_sensitive`),
// operands (via `eq_operands`), backend config, and finally the
// opcode-specific fields via the slow path.
bool HloInstruction::IdenticalInternal(
    const HloInstruction& other,
    const std::function<bool(const HloInstruction*, const HloInstruction*)>&
        eq_operands,
    const std::function<bool(const HloComputation*, const HloComputation*)>&
        eq_computations,
    bool layout_sensitive, bool ignore_channel_id_values) const {
  // An instruction is always identical to itself.
  if (this == &other) {
    return true;
  }

  // Identical instruction must have the same opcode, shape, and identical
  // operands.
  if (opcode() != other.opcode()) {
    return false;
  }
  if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape())
                         : ShapeUtil::Compatible(shape(), other.shape()))) {
    return false;
  }
  if (operands().size() != other.operands().size()) {
    return false;
  }

  // Use an explicit loop rather than ContainerEquals, because copying around
  // std::functions may be too expensive in some cases.
  for (size_t i = 0; i < operands().size(); ++i) {
    if (!eq_operands(operand(i), other.operand(i))) {
      return false;
    }
  }

  // The raw backend config string must match exactly.
  if (backend_config_ != other.backend_config_) {
    return false;
  }

  if (ignore_channel_id_values) {
    // Channel-carrying instructions get a comparison that disregards the
    // concrete channel id values.
    if (auto channel_inst = DynCast<HloChannelInstruction>(this)) {
      return channel_inst->IdenticalSlowPathIgnoringChannelIdValues(
          other, eq_computations);
    }
  }
  return IdenticalSlowPath(other, eq_computations);
}
2056
// Appends `operand` to this instruction's operand list and registers this
// instruction as a user of `operand`.
void HloInstruction::AppendOperand(HloInstruction* operand) {
  if (operand->parent() != nullptr) {
    // Guard against wiring in an operand that its computation has already
    // marked for removal.
    DCHECK(!operand->parent()->IsMarkedAsDead(operand))
        << "Operand " << operand->name() << " is already marked dead";
  }
  operands_.push_back(operand);
  operand->AddUser(this);
}
2065
// Removes the operands at the given indices (which must be strictly
// ascending) in a single left-to-right compaction pass: surviving operands
// are shifted down by the number of removals seen so far, then the vector is
// truncated. Runs in O(operand_count).
void HloInstruction::RemoveOperandsAtAscendingIndices(
    absl::Span<const int> ascending_indices) {
  if (ascending_indices.empty()) {
    return;
  }
  int next_index = 0;
  int removed_count = 0;
  for (int to_remove : ascending_indices) {
    // Shift the operands between the previous removal and this one down by
    // the number of slots already freed.
    while (next_index < to_remove) {
      operands_[next_index - removed_count] = operands_[next_index];
      ++next_index;
    }
    CHECK_LT(to_remove, operands_.size());
    // Skip (i.e. drop) the operand at `to_remove`.
    ++removed_count;
    ++next_index;
  }
  // Shift the tail after the last removed index.
  while (next_index < operands_.size()) {
    operands_[next_index - removed_count] = operands_[next_index];
    ++next_index;
  }
  CHECK_EQ(removed_count, ascending_indices.size());
  operands_.resize(operands_.size() - removed_count);
}
2089
// Registers `user` as a user of this instruction. The user list is kept
// duplicate-free; user_map_ maps each user to its position in users_.
void HloInstruction::AddUser(HloInstruction* user) {
  if (!ContainsKey(user_map_, user)) {
    user_map_.emplace(user, users_.size());
    users_.push_back(user);
  }
}
2096
// Returns the index of `user` in users_. It is a fatal error if `user` is not
// a user of this instruction.
int64 HloInstruction::UserId(HloInstruction* user) {
  auto result = user_map_.find(user);
  CHECK(result != user_map_.end());
  return result->second;
}
2102
HasConstantOperand() const2103 bool HloInstruction::HasConstantOperand() const {
2104 for (const HloInstruction* operand : operands_) {
2105 if (operand->IsConstant()) {
2106 return true;
2107 }
2108 }
2109 return false;
2110 }
2111
// Opcode-specific portion of instruction equality. Called only after the
// fast-path checks (opcode and operand identity) have already matched, so
// this function only has to compare any *extra* state attached to the
// opcode. `eq_computations` decides equality of called computations.
bool HloInstruction::IdenticalSlowPath(
    const HloInstruction& other,
    const std::function<bool(const HloComputation*, const HloComputation*)>&
        eq_computations) const {
  // Perform opcode specific checks.
  switch (opcode()) {
    // The result of these instructions only depend upon their opcode and
    // operands, both of which were already checked by the caller, so they
    // are unconditionally identical here.
    case HloOpcode::kAbs:
    case HloOpcode::kAtan2:
    case HloOpcode::kAdd:
    case HloOpcode::kBitcast:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCollectivePermuteDone:
    case HloOpcode::kComplex:
    case HloOpcode::kConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kCopyStart:
    case HloOpcode::kCopyDone:
    case HloOpcode::kCos:
    case HloOpcode::kDivide:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kAnd:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kPartitionId:
    case HloOpcode::kPopulationCount:
    case HloOpcode::kPower:
    case HloOpcode::kReal:
    case HloOpcode::kRemainder:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
    case HloOpcode::kReplicaId:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kLogistic:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kCbrt:
    case HloOpcode::kSubtract:
    case HloOpcode::kTanh:
    case HloOpcode::kTuple:
    case HloOpcode::kTupleSelect:
      return true;

    // This opcode has complex or special behavior so just return false.
    case HloOpcode::kAfterAll:
    case HloOpcode::kAddDependency:
      return false;

    // Remaining instructions with special values: compare the computations
    // they call using the caller-supplied predicate.
    case HloOpcode::kCall:
      return eq_computations(to_apply(), other.to_apply());
    case HloOpcode::kConditional:
      // Branch counts already match (same operand count); compare each
      // branch computation pairwise.
      for (int j = 0; j < branch_count(); ++j) {
        if (!eq_computations(branch_computation(j),
                             other.branch_computation(j))) {
          return false;
        }
      }
      return true;
    case HloOpcode::kWhile:
      return (eq_computations(while_body(), other.while_body()) &&
              eq_computations(while_condition(), other.while_condition()));

    // Ops migrated to subclasses should never come to this line.
    // TODO(b/80131774): Remove this switch when migration is complete.
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kFft:
    case HloOpcode::kCompare:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReverse:
    case HloOpcode::kConcatenate:
    case HloOpcode::kReduce:
    case HloOpcode::kSort:
    case HloOpcode::kTranspose:
    case HloOpcode::kBroadcast:
    case HloOpcode::kMap:
    case HloOpcode::kSlice:
    case HloOpcode::kConstant:
    case HloOpcode::kIota:
    case HloOpcode::kTrace:
    case HloOpcode::kFusion:
    case HloOpcode::kRng:
    case HloOpcode::kRngBitGenerator:
    case HloOpcode::kRngGetAndUpdateState:
    case HloOpcode::kParameter:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kCollectivePermuteStart:
    case HloOpcode::kConvolution:
    case HloOpcode::kCustomCall:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kPad:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kGather:
    case HloOpcode::kScatter:
    case HloOpcode::kDot:
    case HloOpcode::kDomain:
    case HloOpcode::kGetDimensionSize:
    case HloOpcode::kSetDimensionSize:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kCholesky:
      LOG(FATAL) << "Base class impl called for opcode with subclass: "
                 << opcode();
  }
  // Unreachable when the switch above is exhaustive; keeps compilers that
  // can't prove exhaustiveness happy.
  return false;
}
2252
// Default operand hasher for HloInstruction::Hash(): hashes only the
// operand's shape, never its contents, so hashing cannot recurse through the
// operand graph.
static uint64 HashOperand(const HloInstruction* hlo) {
  return ShapeUtil::Hash(hlo->shape());
}
2256
// Computes a hash of this instruction from its opcode, result shape, operand
// hashes (via the caller-supplied `hash_operand` functor) and any
// subclass-specific state (InnerHash). For cross-module all-reduce
// instructions the operands are deliberately excluded from the hash.
uint64 HloInstruction::Hash(
    const std::function<uint64(const HloInstruction*)>& hash_operand) const {
  using tensorflow::Hash64Combine;

  uint64 hash_value = Hash64Combine(0, static_cast<uint64>(opcode()));
  hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(shape()));

  if (!IsCrossModuleAllReduce()) {
    // Note: an empty operand list simply contributes nothing; no separate
    // emptiness check is needed.
    for (const HloInstruction* operand : operands()) {
      hash_value = Hash64Combine(hash_value, hash_operand(operand));
    }
  }

  hash_value = Hash64Combine(hash_value, InnerHash());
  return hash_value;
}
2275
// Parameterless overload: hashes operands by shape only.
uint64 HloInstruction::Hash() const {
  // Use HashOperand as an argument to prevent non-termination.
  return Hash(HashOperand);
}
2280
// Base-class contribution to Hash(): a fixed, arbitrary constant, since the
// base class carries no extra hashable state.
uint64 HloInstruction::InnerHash() const { return 13; }
2282
// Removes `user` from this instruction's user list in O(1) by swapping the
// last user into the vacated slot ("swap-remove"). `user` must currently be
// registered in user_map_.
void HloInstruction::RemoveUser(HloInstruction* user) {
  auto map_it = user_map_.find(user);
  CHECK(map_it != user_map_.end());

  const int64 index = map_it->second;
  CHECK_EQ(users_[index], user);

  // Move the last user into the position of the removed user. Note
  // users_.back() is already a key in user_map_, so this operator[] updates
  // an existing entry and cannot rehash/invalidate map_it. (If `user` IS the
  // last element, this writes user's own entry, which is erased just below.)
  users_[index] = users_.back();
  user_map_[users_.back()] = index;

  // Remove the user from the map and drop the last slot from the vector,
  // whose contents have been moved to the position of the original user.
  user_map_.erase(map_it);
  users_.pop_back();
}
2299
ReplaceUseWith(HloInstruction * user,HloInstruction * new_producer)2300 Status HloInstruction::ReplaceUseWith(HloInstruction* user,
2301 HloInstruction* new_producer) {
2302 TF_RET_CHECK(
2303 ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
2304 << "this shape: " << ShapeUtil::HumanString(shape())
2305 << ", replacement shape: "
2306 << ShapeUtil::HumanString(new_producer->shape());
2307 return ReplaceUseWithDifferentShape(user, new_producer);
2308 }
2309
// Replaces every occurrence of `this` in `user`'s operand list with
// `new_producer`, with no shape-compatibility requirement, and updates the
// user lists on both producers accordingly.
Status HloInstruction::ReplaceUseWithDifferentShape(
    HloInstruction* user, HloInstruction* new_producer) {
  VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
          << " with " << new_producer->name();

  RemoveUser(user);

  // RemoveUser above CHECKed that `user` was registered as a user of this
  // instruction, so `this` must occur at least once in its operand list.
  // (The previous `>= 0` form of this check was vacuously true: a count can
  // never be negative.)
  TF_RET_CHECK(absl::c_count(user->operands_, this) >= 1);
  std::replace(user->operands_.begin(), user->operands_.end(), this,
               new_producer);
  new_producer->AddUser(user);
  // Custom fusions may not be able to handle deduplicated operands.
  if (user->opcode() == HloOpcode::kFusion) {
    TF_RETURN_IF_ERROR(
        Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
  }
  return Status::OK();
}
2328
ReplaceOperandWith(int64 operand_num,HloInstruction * new_operand)2329 Status HloInstruction::ReplaceOperandWith(int64 operand_num,
2330 HloInstruction* new_operand) {
2331 auto old_operand = operand(operand_num);
2332 TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
2333 new_operand->shape()))
2334 << old_operand->shape() << " is not compatible with "
2335 << new_operand->shape();
2336 return ReplaceOperandWithDifferentShape(operand_num, new_operand);
2337 }
2338
// Replaces operand `operand_num` with `new_operand` without any shape check,
// keeping the producer-side user lists consistent.
Status HloInstruction::ReplaceOperandWithDifferentShape(
    int64 operand_num, HloInstruction* new_operand) {
  TF_RET_CHECK(operand_num >= 0);
  TF_RET_CHECK(operand_num < operand_count());
  HloInstruction* old_operand = mutable_operand(operand_num);
  // No-op if the operand slot already holds new_operand.
  if (old_operand == new_operand) {
    return Status::OK();
  }

  operands_[operand_num] = new_operand;

  VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
          << new_operand->name() << ", was " << old_operand->name();

  // Only drop this instruction from old_operand's user list if old_operand
  // no longer appears in ANY operand slot (it may be used multiple times).
  // The search must run after the slot was overwritten above.
  if (!absl::c_linear_search(operands_, old_operand)) {
    old_operand->RemoveUser(this);
  }
  // AddUser is a no-op if this instruction is already a registered user.
  new_operand->AddUser(this);
  return Status::OK();
}
2359
ReplaceUsesWith(absl::Span<HloInstruction * const> users,HloInstruction * new_producer)2360 Status HloInstruction::ReplaceUsesWith(absl::Span<HloInstruction* const> users,
2361 HloInstruction* new_producer) {
2362 TF_RET_CHECK(
2363 ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
2364 << shape() << " is not compatible with " << new_producer->shape();
2365 return ReplaceAllUsesWithDifferentShape(users, new_producer);
2366 }
2367
ReplaceAllUsesWithDifferentShape(absl::Span<HloInstruction * const> users,HloInstruction * new_producer)2368 Status HloInstruction::ReplaceAllUsesWithDifferentShape(
2369 absl::Span<HloInstruction* const> users, HloInstruction* new_producer) {
2370 for (HloInstruction* user : users) {
2371 TF_RETURN_IF_ERROR(ReplaceUseWithDifferentShape(user, new_producer));
2372 }
2373
2374 if (parent_ && parent_->root_instruction() == this) {
2375 parent_->set_root_instruction(new_producer,
2376 /*accept_different_shape=*/true);
2377 }
2378 return Status::OK();
2379 }
2380
ReplaceAllUsesWith(HloInstruction * new_producer)2381 Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
2382 TF_RET_CHECK(
2383 ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
2384 << shape() << " is not compatible with " << new_producer->shape();
2385 return ReplaceAllUsesWithDifferentShape(new_producer);
2386 }
2387
// Redirects every user of this instruction to `new_producer` without shape
// checks, taking care not to introduce a cycle when `new_producer` itself is
// one of the users.
Status HloInstruction::ReplaceAllUsesWithDifferentShape(
    HloInstruction* new_producer) {
  bool new_producer_is_user = false;
  for (HloInstruction* user : users()) {
    if (user == new_producer) {
      // It's possible that new_producer is a user of this instruction as might
      // be the case when replacing an instruction with a kCopy of itself. In
      // this case, don't do the replacement to avoid creating a cycle in the
      // graph. new_producer remains the only user of this instruction.
      new_producer_is_user = true;
    } else {
      std::replace(user->operands_.begin(), user->operands_.end(), this,
                   new_producer);
      new_producer->AddUser(user);
      // Custom fusions may not be able to handle deduplicated operands.
      if (user->opcode() == HloOpcode::kFusion) {
        TF_RETURN_IF_ERROR(
            Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
      }
    }
  }
  // Rebuild this instruction's user set from scratch: empty, except possibly
  // for new_producer (see above).
  users_.clear();
  user_map_.clear();
  if (new_producer_is_user) {
    AddUser(new_producer);
  }
  // If this instruction was the computation root, the root moves too.
  if (parent_ && parent_->root_instruction() == this) {
    parent_->set_root_instruction(new_producer,
                                  /*accept_different_shape=*/true);
  }

  return Status::OK();
}
2420
IsEffectiveBitcast() const2421 bool HloInstruction::IsEffectiveBitcast() const {
2422 return opcode_ == HloOpcode::kBitcast ||
2423 (opcode_ == HloOpcode::kTranspose &&
2424 ShapeUtil::TransposeIsBitcast(operand(0)->shape(), shape(),
2425 dimensions()));
2426 }
2427
// Returns the single computation applied by this instruction. Only valid
// for the opcodes listed below, each of which calls exactly one computation;
// any other opcode is a fatal error.
HloComputation* HloInstruction::to_apply() const {
  switch (opcode_) {
    case HloOpcode::kCall:
    case HloOpcode::kMap:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kReduce:
    case HloOpcode::kAllReduce:
    case HloOpcode::kScatter:
    case HloOpcode::kSort:
    case HloOpcode::kCustomCall:
      CHECK_EQ(called_computations_.size(), 1);
      return called_computations_[0];
    default:
      LOG(FATAL) << "Invalid opcode for to_apply(): "
                 << HloOpcodeString(opcode());
  }
}
2445
set_to_apply(HloComputation * computation)2446 void HloInstruction::set_to_apply(HloComputation* computation) {
2447 // Don't allow changing the computation for fused instructions so we don't
2448 // have to recompute called_instructions for the entire fusion instruction.
2449 CHECK(!IsFused());
2450 switch (opcode_) {
2451 case HloOpcode::kCall:
2452 case HloOpcode::kMap:
2453 case HloOpcode::kReduceWindow:
2454 case HloOpcode::kReduce:
2455 case HloOpcode::kAllReduce:
2456 case HloOpcode::kScatter:
2457 case HloOpcode::kSort:
2458 case HloOpcode::kCustomCall:
2459 CHECK_EQ(called_computations_.size(), 1);
2460 called_computations_[0] = computation;
2461 break;
2462 default:
2463 LOG(FATAL) << "Invalid opcode for to_apply(): "
2464 << HloOpcodeString(opcode());
2465 }
2466 }
2467
while_condition() const2468 HloComputation* HloInstruction::while_condition() const {
2469 CHECK_EQ(HloOpcode::kWhile, opcode_);
2470 return called_computations_[kConditionComputationIndex];
2471 }
2472
while_body() const2473 HloComputation* HloInstruction::while_body() const {
2474 CHECK_EQ(HloOpcode::kWhile, opcode_);
2475 return called_computations_[kBodyComputationIndex];
2476 }
2477
set_while_condition(HloComputation * computation)2478 void HloInstruction::set_while_condition(HloComputation* computation) {
2479 // Don't allow changing the computation for fused instructions so we don't
2480 // have to recompute called_instructions for the entire fusion instruction.
2481 CHECK(!IsFused());
2482 CHECK_EQ(HloOpcode::kWhile, opcode_);
2483 called_computations_[kConditionComputationIndex] = computation;
2484 }
2485
set_while_body(HloComputation * computation)2486 void HloInstruction::set_while_body(HloComputation* computation) {
2487 // Don't allow changing the computation for fused instructions so we don't
2488 // have to recompute called_instructions for the entire fusion instruction.
2489 CHECK(!IsFused());
2490 CHECK_EQ(HloOpcode::kWhile, opcode_);
2491 called_computations_[kBodyComputationIndex] = computation;
2492 }
2493
while_init() const2494 HloInstruction* HloInstruction::while_init() const {
2495 CHECK_EQ(HloOpcode::kWhile, opcode_);
2496 return operands_[0];
2497 }
2498
true_computation() const2499 HloComputation* HloInstruction::true_computation() const {
2500 CHECK_EQ(HloOpcode::kConditional, opcode_);
2501 CHECK_EQ(PRED, operand(0)->shape().element_type());
2502 return called_computations_[kTrueComputationIndex];
2503 }
2504
false_computation() const2505 HloComputation* HloInstruction::false_computation() const {
2506 CHECK_EQ(HloOpcode::kConditional, opcode_);
2507 CHECK_EQ(PRED, operand(0)->shape().element_type());
2508 return called_computations_[kFalseComputationIndex];
2509 }
2510
// Returns all branch computations of a kConditional instruction; for
// kConditional these are exactly the called computations.
const std::vector<HloComputation*>& HloInstruction::branch_computations()
    const {
  CHECK(HloOpcode::kConditional == opcode_);
  return called_computations_;
}
2516
branch_count() const2517 int HloInstruction::branch_count() const {
2518 CHECK(HloOpcode::kConditional == opcode_);
2519 return called_computations_.size();
2520 }
2521
branch_computation(int b) const2522 HloComputation* HloInstruction::branch_computation(int b) const {
2523 CHECK(HloOpcode::kConditional == opcode_);
2524 CHECK_GE(b, 0);
2525 CHECK_LT(b, called_computations_.size());
2526 return called_computations_[b];
2527 }
2528
set_branch_computation(int b,HloComputation * computation)2529 void HloInstruction::set_branch_computation(int b,
2530 HloComputation* computation) {
2531 // Don't allow changing the computation for fused instructions so we don't
2532 // have to recompute called_instructions for the entire fusion instruction.
2533 CHECK(!IsFused());
2534 CHECK_EQ(HloOpcode::kConditional, opcode_);
2535 called_computations_[b] = computation;
2536 }
2537
// Renders this instruction's type signature as "(op shapes) -> result
// shape", using human-readable shape strings.
string HloInstruction::SignatureString() const {
  auto append_shape = [](string* out, HloInstruction* operand) {
    StrAppend(out, ShapeUtil::HumanString(operand->shape()));
  };
  return StrCat("(", StrJoin(operands_, ", ", append_shape), ") -> ",
                ShapeUtil::HumanString(shape()));
}
2545
// Formats an instruction/computation name for printing. When ids are
// suppressed, everything from the first '.' onward (the numeric id suffix)
// is dropped; otherwise the name is returned verbatim.
string PrintName(const string& name, bool print_ids) {
  if (!print_ids) {
    // substr with npos (no dot found) returns the whole name unchanged.
    auto dot_position = name.find_first_of('.');
    return name.substr(0, dot_position);
  }
  return name;
}
2554
namespace {

// Stack of (operand index, instruction) pairs used by DFS traversals.
using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;

// Formats a name per `options`: applies PrintName's id handling and
// optionally prefixes '%'.
string PrintNameInternal(const string& name, const HloPrintOptions& options) {
  return StrCat(options.print_percent() ? "%" : "",
                PrintName(name, options.print_ids()));
}

// Logs one directed cycle through `child` found during a DFS whose stack is
// `dfs_stack`. Destructively pops `dfs_stack` down to `child`.
void PrintCycle(const HloInstruction* child, DFSStack* dfs_stack) {
  // This set contains HloInstructions from the top of `DFSStack` that might
  // belong to the cycle, i.e. if DFSStack :=[back,...,child,...,top], then
  // `subgraph` := {child,...,top}.
  absl::flat_hash_set<const HloInstruction*> subgraph;
  while (!dfs_stack->empty() && dfs_stack->back().second != child) {
    subgraph.insert(dfs_stack->back().second);
    dfs_stack->pop_back();
  }
  // Start dfs at `child` and find a cycle with all nodes in `subgraph`.
  absl::flat_hash_set<const HloInstruction*> visited;
  absl::InlinedVector<const HloInstruction*, 16> dfs;
  dfs.push_back(child);
  while (!dfs.empty()) {
    bool found_next_instr = false;
    for (const auto& user : dfs.back()->users()) {
      if (user == child) {
        // Cycle closed: the current DFS stack plus `child` is the cycle.
        dfs.push_back(child);
        LOG(INFO) << "\n\nDirected cycle:\n  "
                  << absl::StrJoin(
                         dfs, "\n  ",
                         [](std::string* out, const HloInstruction* instr) {
                           out->append(instr->name());
                         });
        return;
      }
      // Restrict the search to candidate cycle members not yet visited.
      if (!subgraph.contains(user) || visited.contains(user)) {
        continue;
      }
      visited.insert(user);
      dfs.push_back(user);
      found_next_instr = true;
    }
    // Dead end: backtrack.
    if (!found_next_instr) {
      dfs.pop_back();
    }
  }
}

}  // namespace
2604
// Top-level string rendering: starts a fresh canonical-name map and
// delegates to the map-threading implementation.
string HloInstruction::ToString(const HloPrintOptions& options) const {
  CanonicalNameMap new_map;
  return ToStringWithCanonicalNameMap(options, &new_map);
}
2609
IsOpElementwise(HloOpcode opcode)2610 bool HloInstruction::IsOpElementwise(HloOpcode opcode) {
2611 switch (opcode) {
2612 // Unary elementwise operations.
2613 case HloOpcode::kAbs:
2614 case HloOpcode::kRoundNearestAfz:
2615 case HloOpcode::kCeil:
2616 case HloOpcode::kClz:
2617 case HloOpcode::kConvert:
2618 case HloOpcode::kBitcastConvert:
2619 case HloOpcode::kCopy:
2620 case HloOpcode::kCos:
2621 case HloOpcode::kExp:
2622 case HloOpcode::kExpm1:
2623 case HloOpcode::kFloor:
2624 case HloOpcode::kImag:
2625 case HloOpcode::kIsFinite:
2626 case HloOpcode::kLog:
2627 case HloOpcode::kLog1p:
2628 case HloOpcode::kNot:
2629 case HloOpcode::kNegate:
2630 case HloOpcode::kPopulationCount:
2631 case HloOpcode::kReal:
2632 case HloOpcode::kReducePrecision:
2633 case HloOpcode::kRsqrt:
2634 case HloOpcode::kLogistic:
2635 case HloOpcode::kSign:
2636 case HloOpcode::kSin:
2637 case HloOpcode::kSqrt:
2638 case HloOpcode::kCbrt:
2639 case HloOpcode::kTanh:
2640 return true;
2641
2642 // Binary elementwise operations, the same as in IsElementwiseBinary().
2643 case HloOpcode::kAdd:
2644 case HloOpcode::kAtan2:
2645 case HloOpcode::kCompare:
2646 case HloOpcode::kComplex:
2647 case HloOpcode::kDivide:
2648 case HloOpcode::kMaximum:
2649 case HloOpcode::kMinimum:
2650 case HloOpcode::kMultiply:
2651 case HloOpcode::kPower:
2652 case HloOpcode::kRemainder:
2653 case HloOpcode::kSubtract:
2654 case HloOpcode::kAnd:
2655 case HloOpcode::kOr:
2656 case HloOpcode::kXor:
2657 case HloOpcode::kShiftLeft:
2658 case HloOpcode::kShiftRightArithmetic:
2659 case HloOpcode::kShiftRightLogical:
2660 return true;
2661
2662 // Ternary elementwise operations.
2663 case HloOpcode::kSelect:
2664 case HloOpcode::kClamp:
2665 return true;
2666
2667 default:
2668 return false;
2669 }
2670 }
2671
// Elementwise query, optionally restricted to a single operand index.
// kDynamicUpdateSlice is elementwise only with respect to operand 0; all
// other opcodes defer to the opcode-level classification.
bool HloInstruction::IsElementwiseImpl(
    const absl::optional<int64>& operand_idx) const {
  if (opcode_ == HloOpcode::kDynamicUpdateSlice) {
    return operand_idx.has_value() && *operand_idx == 0;
  }
  return IsOpElementwise(opcode_);
}
2679
IsCrossModuleAllReduce() const2680 bool HloInstruction::IsCrossModuleAllReduce() const {
2681 return opcode() == HloOpcode::kAllReduce && channel_id();
2682 }
2683
IsCrossReplicaAllReduce() const2684 bool HloInstruction::IsCrossReplicaAllReduce() const {
2685 return opcode() == HloOpcode::kAllReduce && !channel_id();
2686 }
2687
// Core string rendering. `canonical_name_map` threads canonicalized names
// through nested-computation printing so identical graphs print identically.
string HloInstruction::ToStringWithCanonicalNameMap(
    const HloPrintOptions& options,
    CanonicalNameMap* canonical_name_map) const {
  string result = "";

  // Logic to print the instruction name (e.g. "%foo = ").
  if (options.canonicalize_instruction_names()) {
    if (options.is_in_nested_computation()) {
      // If we are canonicalizing instruction names and this is a top-level
      // HloInstruction::ToString() call, don't print an instruction name.
      StrAppend(&result,
                PrintNameInternal(canonical_name_map->LookupOrInsert(name()),
                                  options),
                " = ");
    }
  } else {
    StrAppend(&result, PrintNameInternal(name(), options), " = ");
  }

  if (options.print_result_shape()) {
    // Print shape, with or without layout per the options.
    if (options.include_layout_in_shapes()) {
      StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ");
    } else {
      StrAppend(&result, ShapeUtil::HumanString(shape()), " ");
    }
  }

  // Print opcode, operand(s).
  StrAppend(&result, HloOpcodeString(opcode()), "(",
            OperandsToStringWithCanonicalNameMap(options, canonical_name_map),
            ")");

  // Print additional attributes. If an instruction contains a subcomputation,
  // the subcomputation is also printed here.
  for (const string& extra : ExtraAttributesToString(options)) {
    StrAppend(&result, ", ", extra);
  }

  // Metadata is printed only if at least one of its fields is non-empty.
  if (options.print_metadata() &&
      (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
       !metadata_.source_file().empty())) {
    StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
  }
  // Backend config is escaped since it may contain arbitrary bytes.
  if (options.print_backend_config() && !backend_config_.empty()) {
    StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\"");
  }
  return result;
}
2737
// Renders the operand list with a fresh canonical-name map; convenience
// wrapper over the map-threading implementation.
string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
  CanonicalNameMap new_map;
  return OperandsToStringWithCanonicalNameMap(options, &new_map);
}
2742
// Renders the comma-separated operand list, honoring compact mode (at most
// four operands shown, with a "...(+N)" suffix) and the name/shape printing
// options.
string HloInstruction::OperandsToStringWithCanonicalNameMap(
    const HloPrintOptions& options,
    CanonicalNameMap* canonical_name_map) const {
  string operands;
  absl::Span<HloInstruction* const> slice(operands_);
  const int64 kMaxOperandsToShowIfCompact = 4;
  if (options.compact_operands() &&
      slice.size() > kMaxOperandsToShowIfCompact) {
    slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
  }
  operands = StrJoin(slice, ", ", [&](string* out, HloInstruction* operand) {
    // If the operand has already been deleted, put `null` in the string
    // output.
    if (operand == nullptr) {
      StrAppend(out, "null ");
      return;
    }
    // Collect the pieces (shape, then name) and join with a space.
    std::vector<string> str;
    if (options.print_operand_shape()) {
      if (options.include_layout_in_shapes()) {
        str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
      } else {
        str.push_back(ShapeUtil::HumanString(operand->shape()));
      }
    }

    // In a top-level HloInstruction::ToString() call, the operand name is not
    // part of the canonical string.
    if (options.canonicalize_instruction_names() &&
        options.is_in_nested_computation()) {
      str.push_back(PrintNameInternal(
          canonical_name_map->LookupOrInsert(operand->name()), options));
    } else if (options.print_operand_names()) {
      str.push_back(PrintNameInternal(operand->name(), options));
    }
    StrAppend(out, StrJoin(str, " "));
  });
  // Indicate how many operands were elided in compact mode.
  const int64 remaining = operands_.size() - slice.size();
  if (slice.size() != operands_.size()) {
    StrAppend(&operands, ", ...(+", remaining, ")");
  }
  return operands;
}
2785
// Collects the instruction's extra printable attributes: subclass-specific
// attributes, called computations (by name or as full bodies, depending on
// print_subcomputation_mode), sharding, frontend attributes, partitioning,
// and control dependencies. Each entry is one already-formatted
// "key=value" string.
std::vector<string> HloInstruction::ExtraAttributesToString(
    const HloPrintOptions& options) const {
  // Subclass-provided attributes come first (if attribute printing is on).
  std::vector<string> extra = options.print_extra_attributes()
                                  ? ExtraAttributesToStringImpl(options)
                                  : std::vector<string>();

  // kNameOnly: refer to called computations by name.
  if (options.print_subcomputation_mode() ==
      HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
    if (opcode() == HloOpcode::kWhile) {
      extra.push_back(StrCat(
          "condition=", PrintNameInternal(while_condition()->name(), options)));
      extra.push_back(
          StrCat("body=", PrintNameInternal(while_body()->name(), options)));
    } else if (opcode() == HloOpcode::kSelectAndScatter) {
      extra.push_back(
          StrCat("select=", PrintNameInternal(select()->name(), options)));
      extra.push_back(
          StrCat("scatter=", PrintNameInternal(scatter()->name(), options)));
    } else if (opcode() == HloOpcode::kConditional) {
      // Predicated conditionals print true/false computations; the indexed
      // form prints the whole branch list.
      if (operand(0)->shape().element_type() == PRED) {
        extra.push_back(
            StrCat("true_computation=",
                   PrintNameInternal(true_computation()->name(), options)));
        extra.push_back(
            StrCat("false_computation=",
                   PrintNameInternal(false_computation()->name(), options)));
      } else {
        extra.push_back(StrCat(
            "branch_computations={",
            StrJoin(branch_computations(), ", ",
                    [&](string* out, const HloComputation* computation) {
                      StrAppend(
                          out, PrintNameInternal(computation->name(), options));
                    }),
            "}"));
      }
    } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
               opcode() == HloOpcode::kReduceWindow ||
               opcode() == HloOpcode::kReduce ||
               opcode() == HloOpcode::kAllReduce ||
               opcode() == HloOpcode::kScatter ||
               opcode() == HloOpcode::kSort) {
      // Single-computation opcodes (see to_apply()).
      extra.push_back(
          StrCat("to_apply=", PrintNameInternal(to_apply()->name(), options)));
    } else if (opcode() == HloOpcode::kCustomCall) {
      // Custom calls may call zero or more computations.
      if (!called_computations().empty()) {
        extra.push_back(StrCat(
            "called_computations={",
            StrJoin(called_computations(), ", ",
                    [&](string* out, const HloComputation* computation) {
                      StrAppend(
                          out, PrintNameInternal(computation->name(), options));
                    }),
            "}"));
      }
    } else if (!called_computations().empty()) {
      // Fallback for any other opcode that calls computations.
      extra.push_back(StrCat(
          "calls=",
          StrJoin(called_computations(), ", ",
                  [&](string* out, const HloComputation* computation) {
                    StrAppend(out,
                              PrintNameInternal(computation->name(), options));
                  })));
    }
  } else if (options.print_subcomputation_mode() ==
             HloPrintOptions::PrintSubcomputationMode::kFullBodies) {
    // kFullBodies: inline the full text of each called computation, marking
    // the recursive prints as nested.
    HloPrintOptions new_options = options;
    new_options.set_is_in_nested_computation(true);
    switch (opcode()) {
      case HloOpcode::kWhile:
        extra.push_back(
            StrCat("condition=\n", while_condition()->ToString(new_options)));
        extra.push_back(StrCat("body=\n", while_body()->ToString(new_options)));
        break;
      case HloOpcode::kSelectAndScatter:
        extra.push_back(StrCat("select=\n", select()->ToString(new_options)));
        extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options)));
        break;
      case HloOpcode::kConditional:
        if (operand(0)->shape().element_type() == PRED) {
          extra.push_back(StrCat("true_computation=\n",
                                 true_computation()->ToString(new_options)));
          extra.push_back(StrCat("false_computation=\n",
                                 false_computation()->ToString(new_options)));
        } else {
          extra.push_back(StrCat(
              "branch_computations={\n",
              StrJoin(branch_computations(), ",\n",
                      [&](string* out, const HloComputation* computation) {
                        StrAppend(out, computation->ToString(new_options));
                      }),
              "\n}"));
        }
        break;
      case HloOpcode::kCall:
      case HloOpcode::kMap:
      case HloOpcode::kReduceWindow:
      case HloOpcode::kReduce:
      case HloOpcode::kAllReduce:
      case HloOpcode::kScatter:
      case HloOpcode::kSort:
        extra.push_back(
            StrCat("to_apply=\n", to_apply()->ToString(new_options)));
        break;
      default:
        if (!called_computations().empty()) {
          extra.push_back(StrCat(
              "calls=\n",
              StrJoin(called_computations(), ", ",
                      [&](string* out, const HloComputation* computation) {
                        StrAppend(out, computation->ToString(new_options));
                      })));
        }
        break;
    }
  }

  if (has_sharding()) {
    extra.push_back(
        StrCat("sharding=", sharding().ToString(options.print_metadata())));
  }
  if (!frontend_attributes_.map().empty()) {
    extra.push_back(StrCat("frontend_attributes=",
                           FrontendAttributesToString(frontend_attributes_)));
  }
  if (!outer_dimension_partitions_.empty()) {
    extra.push_back(absl::StrFormat("outer_dimension_partitions={%s}",
                                    StrJoin(outer_dimension_partitions_, ",")));
  }

  if (options.print_control_dependencies() && !control_predecessors_.empty()) {
    extra.push_back(StrCat("control-predecessors={",
                           StrJoin(control_predecessors_, ", ",
                                   [&](string* out, HloInstruction* pre) {
                                     StrAppend(out, PrintNameInternal(
                                                        pre->name(), options));
                                   }),
                           "}"));
  }

  return extra;
}
2928
// Compact rendering: "%name = opcode(%op1, %op2, ...)" with no shapes or
// extra attributes.
string HloInstruction::ToShortString() const {
  auto append_operand = [](string* out, HloInstruction* operand) {
    StrAppend(out, "%", operand->name());
  };
  return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
                StrJoin(operands_, ", ", append_operand), ")");
}
2937
// Serializes this instruction to an HloInstructionProto. Operands, control
// predecessors and called computations are referenced by unique id, so the
// instruction must already belong to a module (which assigns ids).
HloInstructionProto HloInstruction::ToProto() const {
  HloInstructionProto proto;
  CHECK(unique_id_ != -1)
      << "This instruction does not have a valid id. Please make sure the "
         "instruction is inside a module before dumping it.";
  proto.set_id(unique_id_);
  proto.set_name(name_);
  proto.set_opcode(HloOpcodeString(opcode_));
  *proto.mutable_shape() = shape_.ToProto();
  for (const HloInstruction* operand : operands_) {
    proto.add_operand_ids(operand->unique_id());
  }
  for (const HloInstruction* control : control_predecessors_) {
    proto.add_control_predecessor_ids(control->unique_id());
  }

  *proto.mutable_metadata() = metadata_;
  proto.set_backend_config(backend_config_);
  // Fusion instructions handle their called (fused) computation in the
  // subclass; only non-fusion instructions record called computation ids
  // here.
  if (opcode() != HloOpcode::kFusion) {
    for (const HloComputation* computation : called_computations_) {
      proto.add_called_computation_ids(computation->unique_id());
    }
  }

  if (has_sharding()) {
    *proto.mutable_sharding() = sharding().ToProto();
  }
  if (!outer_dimension_partitions_.empty()) {
    for (const auto& idx : outer_dimension_partitions_) {
      proto.mutable_outer_dimension_partitions()->Add(idx);
    }
  }

  *proto.mutable_frontend_attributes() = frontend_attributes_;

  return proto;
}
2975
// Buckets this instruction into a coarse category string used for grouping:
// shape-manipulating ops, elementwise ops, or the raw opcode name.
string HloInstruction::ToCategory() const {
  switch (opcode()) {
    case HloOpcode::kTranspose:
    case HloOpcode::kCopy:
    case HloOpcode::kReshape:
    case HloOpcode::kDynamicReshape:
      return "data formatting";
    default:
      break;
  }

  if (IsElementwise()) {
    return "non-fusion elementwise";
  }

  return HloOpcodeString(opcode());
}
2989
tracing() const2990 HloInstruction* HloInstruction::tracing() const { return trace_instruction_; }
2991
set_tracing(HloInstruction * trace_instruction)2992 void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
2993 trace_instruction_ = trace_instruction;
2994 }
2995
IsFused() const2996 bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
2997
IsCustomCall(absl::string_view target) const2998 bool HloInstruction::IsCustomCall(absl::string_view target) const {
2999 return opcode() == HloOpcode::kCustomCall && custom_call_target() == target;
3000 }
3001
IsInputFusion() const3002 bool HloInstruction::IsInputFusion() const {
3003 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kInput;
3004 }
3005
IsLoopFusion() const3006 bool HloInstruction::IsLoopFusion() const {
3007 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kLoop;
3008 }
3009
IsOutputFusion() const3010 bool HloInstruction::IsOutputFusion() const {
3011 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kOutput;
3012 }
3013
IsCustomFusion() const3014 bool HloInstruction::IsCustomFusion() const {
3015 return opcode() == HloOpcode::kFusion && fusion_kind() == FusionKind::kCustom;
3016 }
3017
IsFusible() const3018 bool HloInstruction::IsFusible() const {
3019 // Instructions which are traced should not be fused.
3020 if (tracing()) {
3021 return false;
3022 }
3023 // Some kinds of instructions don't make sense to fuse.
3024 switch (opcode_) {
3025 case HloOpcode::kDomain:
3026 case HloOpcode::kParameter:
3027 case HloOpcode::kWhile:
3028 case HloOpcode::kConditional:
3029 case HloOpcode::kCall:
3030 return false;
3031 // Fusions are always fusible.
3032 case HloOpcode::kFusion:
3033 // Side effecting reduce and reduce window would be invalid HLO.
3034 case HloOpcode::kMap:
3035 case HloOpcode::kReduce:
3036 case HloOpcode::kReduceWindow:
3037 return true;
3038 case HloOpcode::kRng:
3039 return user_count() <= 1;
3040 // Side effecting instructions cannot be fused.
3041 default:
3042 return !HasSideEffect();
3043 }
3044 }
3045
// Base constructor shared by all instruction kinds: records the opcode and
// shape, and derives the default instruction name from the opcode.
// unique_id_ stays -1 until the instruction is added to a computation.
HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
    : unique_id_(-1),
      opcode_(opcode),
      shape_(shape),
      name_(HloOpcodeString(opcode)),
      marked_as_dead_(false) {
  // The shape must be valid, but its layout is allowed to be absent.
  TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
}
3054
// Dispatches to the visitor method corresponding to this instruction's
// opcode. Templated over the instruction pointer type so that the same
// switch serves both the mutable (DfsHloVisitor) and const
// (ConstDfsHloVisitor) visitor hierarchies; see the explicit instantiations
// below.
template <typename HloInstructionPtr>
Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
  switch (opcode_) {
    case HloOpcode::kAbs:
      return visitor->HandleAbs(this);
    case HloOpcode::kAtan2:
      return visitor->HandleAtan2(this);
    case HloOpcode::kRoundNearestAfz:
      return visitor->HandleRound(this);
    case HloOpcode::kBatchNormTraining:
      return visitor->HandleBatchNormTraining(this);
    case HloOpcode::kBatchNormInference:
      return visitor->HandleBatchNormInference(this);
    case HloOpcode::kBatchNormGrad:
      return visitor->HandleBatchNormGrad(this);
    case HloOpcode::kLogistic:
      return visitor->HandleLogistic(this);
    case HloOpcode::kSign:
      return visitor->HandleSign(this);
    case HloOpcode::kConstant:
      return visitor->HandleConstant(this);
    case HloOpcode::kGetTupleElement:
      return visitor->HandleGetTupleElement(this);
    case HloOpcode::kParameter:
      return visitor->HandleParameter(this);
    case HloOpcode::kCompare:
      return visitor->HandleCompare(this);
    case HloOpcode::kComplex:
      return visitor->HandleComplex(this);
    case HloOpcode::kAdd:
      return visitor->HandleAdd(this);
    case HloOpcode::kDivide:
      return visitor->HandleDivide(this);
    case HloOpcode::kSubtract:
      return visitor->HandleSubtract(this);
    case HloOpcode::kMaximum:
      return visitor->HandleMaximum(this);
    case HloOpcode::kMinimum:
      return visitor->HandleMinimum(this);
    case HloOpcode::kAnd:
      return visitor->HandleAnd(this);
    case HloOpcode::kOr:
      return visitor->HandleOr(this);
    case HloOpcode::kXor:
      return visitor->HandleXor(this);
    case HloOpcode::kShiftLeft:
      return visitor->HandleShiftLeft(this);
    case HloOpcode::kShiftRightArithmetic:
      return visitor->HandleShiftRightArithmetic(this);
    case HloOpcode::kShiftRightLogical:
      return visitor->HandleShiftRightLogical(this);
    case HloOpcode::kConcatenate:
      return visitor->HandleConcatenate(this);
    case HloOpcode::kConvert:
      return visitor->HandleConvert(this);
    case HloOpcode::kBitcastConvert:
      return visitor->HandleBitcastConvert(this);
    case HloOpcode::kCopy:
      return visitor->HandleCopy(this);
    case HloOpcode::kMultiply:
      return visitor->HandleMultiply(this);
    case HloOpcode::kDot:
      return visitor->HandleDot(this);
    case HloOpcode::kPower:
      return visitor->HandlePower(this);
    case HloOpcode::kRemainder:
      return visitor->HandleRemainder(this);
    case HloOpcode::kSelect:
      return visitor->HandleSelect(this);
    case HloOpcode::kTupleSelect:
      return visitor->HandleTupleSelect(this);
    case HloOpcode::kConvolution:
      return visitor->HandleConvolution(this);
    case HloOpcode::kFft:
      return visitor->HandleFft(this);
    case HloOpcode::kAllGather:
      return visitor->HandleAllGather(this);
    case HloOpcode::kAllReduce:
      return visitor->HandleAllReduce(this);
    case HloOpcode::kAllToAll:
      return visitor->HandleAllToAll(this);
    case HloOpcode::kCollectivePermute:
      return visitor->HandleCollectivePermute(this);
    case HloOpcode::kCollectivePermuteStart:
      return visitor->HandleCollectivePermuteStart(this);
    case HloOpcode::kCollectivePermuteDone:
      return visitor->HandleCollectivePermuteDone(this);
    case HloOpcode::kReplicaId:
      return visitor->HandleReplicaId(this);
    case HloOpcode::kPartitionId:
      return visitor->HandlePartitionId(this);
    case HloOpcode::kTuple:
      return visitor->HandleTuple(this);
    case HloOpcode::kMap:
      return visitor->HandleMap(this);
    case HloOpcode::kClamp:
      return visitor->HandleClamp(this);
    case HloOpcode::kReduce:
      return visitor->HandleReduce(this);
    case HloOpcode::kReduceWindow:
      return visitor->HandleReduceWindow(this);
    case HloOpcode::kSelectAndScatter:
      return visitor->HandleSelectAndScatter(this);
    case HloOpcode::kNegate:
      return visitor->HandleNegate(this);
    case HloOpcode::kExp:
      return visitor->HandleExp(this);
    case HloOpcode::kExpm1:
      return visitor->HandleExpm1(this);
    case HloOpcode::kFloor:
      return visitor->HandleFloor(this);
    case HloOpcode::kCeil:
      return visitor->HandleCeil(this);
    case HloOpcode::kClz:
      return visitor->HandleClz(this);
    case HloOpcode::kLog:
      return visitor->HandleLog(this);
    case HloOpcode::kLog1p:
      return visitor->HandleLog1p(this);
    case HloOpcode::kTanh:
      return visitor->HandleTanh(this);
    case HloOpcode::kCos:
      return visitor->HandleCos(this);
    case HloOpcode::kSin:
      return visitor->HandleSin(this);
    case HloOpcode::kSqrt:
      return visitor->HandleSqrt(this);
    case HloOpcode::kCbrt:
      return visitor->HandleCbrt(this);
    case HloOpcode::kRsqrt:
      return visitor->HandleRsqrt(this);
    case HloOpcode::kReal:
      return visitor->HandleReal(this);
    case HloOpcode::kImag:
      return visitor->HandleImag(this);
    case HloOpcode::kIsFinite:
      return visitor->HandleIsFinite(this);
    case HloOpcode::kNot:
      return visitor->HandleNot(this);
    case HloOpcode::kPopulationCount:
      return visitor->HandlePopulationCount(this);
    case HloOpcode::kBitcast:
      return visitor->HandleBitcast(this);
    case HloOpcode::kBroadcast:
      return visitor->HandleBroadcast(this);
    case HloOpcode::kPad:
      return visitor->HandlePad(this);
    case HloOpcode::kReshape:
      return visitor->HandleReshape(this);
    case HloOpcode::kDynamicReshape:
      return visitor->HandleDynamicReshape(this);
    case HloOpcode::kTranspose:
      return visitor->HandleTranspose(this);
    case HloOpcode::kReverse:
      return visitor->HandleReverse(this);
    case HloOpcode::kReducePrecision:
      return visitor->HandleReducePrecision(this);
    case HloOpcode::kSlice:
      return visitor->HandleSlice(this);
    case HloOpcode::kDynamicSlice:
      return visitor->HandleDynamicSlice(this);
    case HloOpcode::kDynamicUpdateSlice:
      return visitor->HandleDynamicUpdateSlice(this);
    case HloOpcode::kSort:
      return visitor->HandleSort(this);
    case HloOpcode::kInfeed:
      return visitor->HandleInfeed(this);
    case HloOpcode::kOutfeed:
      return visitor->HandleOutfeed(this);
    case HloOpcode::kRng:
      return visitor->HandleRng(this);
    case HloOpcode::kRngBitGenerator:
      return visitor->HandleRngBitGenerator(this);
    case HloOpcode::kRngGetAndUpdateState:
      return visitor->HandleRngGetAndUpdateState(this);
    case HloOpcode::kWhile:
      return visitor->HandleWhile(this);
    case HloOpcode::kFusion:
      return visitor->HandleFusion(this);
    case HloOpcode::kCall:
      return visitor->HandleCall(this);
    case HloOpcode::kConditional:
      return visitor->HandleConditional(this);
    case HloOpcode::kCustomCall:
      return visitor->HandleCustomCall(this);
    case HloOpcode::kCopyStart:
      return visitor->HandleCopyStart(this);
    case HloOpcode::kCopyDone:
      return visitor->HandleCopyDone(this);
    case HloOpcode::kRecv:
      return visitor->HandleRecv(this);
    case HloOpcode::kRecvDone:
      return visitor->HandleRecvDone(this);
    case HloOpcode::kSend:
      return visitor->HandleSend(this);
    case HloOpcode::kSendDone:
      return visitor->HandleSendDone(this);
    case HloOpcode::kGather:
      return visitor->HandleGather(this);
    case HloOpcode::kScatter:
      return visitor->HandleScatter(this);
    case HloOpcode::kDomain:
      return visitor->HandleDomain(this);
    case HloOpcode::kAfterAll:
      return visitor->HandleAfterAll(this);
    case HloOpcode::kAddDependency:
      return visitor->HandleAddDependency(this);
    case HloOpcode::kIota:
      return visitor->HandleIota(this);
    case HloOpcode::kGetDimensionSize:
      return visitor->HandleGetDimensionSize(this);
    case HloOpcode::kSetDimensionSize:
      return visitor->HandleSetDimensionSize(this);
    case HloOpcode::kTriangularSolve:
      return visitor->HandleTriangularSolve(this);
    case HloOpcode::kCholesky:
      return visitor->HandleCholesky(this);

    // These opcodes are not handled here.
    case HloOpcode::kTrace:
      return Status::OK();
  }
  // Reached only if opcode_ holds a value outside the HloOpcode enum.
  return InternalError(
      "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "
      "please file a bug for XLA.",
      HloOpcodeString(opcode_));
}

// Explicit instantiations.
template Status HloInstruction::Visit(DfsHloVisitor* visitor);
template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
3286
3287 // Push "child" onto the dfs_stack if not already visited. Returns false if a
3288 // cycle was detected, and true otherwise.
3289 template <typename Visitor>
PushDFSChild(Visitor * visitor,DFSStack * dfs_stack,HloInstruction * child)3290 inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
3291 HloInstruction* child) {
3292 CHECK(child != nullptr);
3293 const int id = child->unique_id();
3294 CHECK_GE(id, 0) << "instruction may not have a parent computation";
3295 switch (visitor->GetVisitState(id)) {
3296 case Visitor::kVisiting:
3297 return false;
3298
3299 case Visitor::kVisited:
3300 // Nothing to do
3301 return true;
3302
3303 case Visitor::kNotVisited:
3304 dfs_stack->push_back(std::make_pair(id, child));
3305 return true;
3306 }
3307 }
3308
// Comparator over DFS stack entries; each entry pairs an instruction's
// unique id with the instruction pointer.
using InternalCompareFunction =
    std::function<bool(std::pair<int, const HloInstruction*>,
                       std::pair<int, const HloInstruction*>)>;

// Iterative post-order DFS from `root`. Each reachable instruction is
// Preprocess'ed, Visit'ed and Postprocess'ed exactly once, after all of its
// operands (and, unless `ignore_control_predecessors` is set, its control
// predecessors). When `operand_order` is non-null, newly pushed children are
// sorted by it before traversal. Returns FailedPrecondition if a cycle is
// detected.
template <typename Visitor>
static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
                           const InternalCompareFunction* operand_order,
                           bool ignore_control_predecessors) {
  // Calculating the instruction count within a module can be expensive on large
  // models so only do it if the visit state is empty. This will help when the
  // same visitor is reused across many computations of a single module.
  if (visitor->VisitStateCapacity() == 0) {
    visitor->ReserveVisitStates(root->GetModule()->instruction_count());
  }

  // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
  //
  // We need to keep track of both the id and the instruction because
  // instructions can get deleted while they are on the stack, so we
  // can't always use the (potentially dead) instruction object to grab
  // its id.
  DFSStack dfs_stack;
  dfs_stack.emplace_back(root->unique_id(), root);

  do {
    DCHECK(!dfs_stack.empty());

    int current_id = dfs_stack.back().first;
    HloInstruction* current_node = dfs_stack.back().second;
    CHECK_GE(current_id, 0) << current_id << ": " << current_node
                            << ": instruction may not have parent computation";
    typename Visitor::VisitState visit_state =
        visitor->GetVisitState(current_id);
    if (visit_state == Visitor::kVisited) {
      dfs_stack.pop_back();
      VLOG(3) << "Not visiting HLO (id = " << current_id
              << ") as it was already visited.";
      continue;
    }

    if (visit_state == Visitor::kVisiting) {
      // Second time this node reaches the top of the stack: all of its
      // children have been processed, so visit the node itself.
      dfs_stack.pop_back();

      TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
      VLOG(2) << "Visiting HLO %" << current_node->name();
      TF_RETURN_IF_ERROR(current_node->Visit(visitor));
      visitor->SetVisitState(current_id, Visitor::kVisited);
      TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
      continue;
    }

    // First encounter: mark the node in-progress and push its children.
    visitor->SetVisitState(current_id, Visitor::kVisiting);

    const size_t old_dfs_stack_size = dfs_stack.size();
    for (HloInstruction* child : current_node->operands()) {
      if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
        PrintCycle(child, &dfs_stack);
        return FailedPrecondition(
            "A cycle is detected while visiting instruction %s",
            current_node->ToString());
      }
    }

    if (!ignore_control_predecessors) {
      for (HloInstruction* child : current_node->control_predecessors()) {
        if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
          PrintCycle(child, &dfs_stack);
          return FailedPrecondition(
              "A cycle is detected while visiting instruction %s",
              current_node->ToString());
        }
      }
    }

    if (operand_order != nullptr) {
      std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
                *operand_order);
    }

    // This makes the traversal order the same as what you'd expect
    // out of a recursive algorithm.
    std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
  } while (!dfs_stack.empty());

  return Status::OK();
}
3394
// Runs a post-order DFS rooted at this instruction, invoking `visitor` on
// every reachable instruction. Control predecessors are traversed unless
// `ignore_control_predecessors` is set; `visitor->FinishVisit` is called at
// the end iff `call_finish_visit` is true.
template <typename HloInstructionPtr>
Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
                              bool call_finish_visit,
                              bool ignore_control_predecessors) {
  VLOG(3) << "HloInstruction::Accept(%" << name() << ")";
  TF_RETURN_IF_ERROR(
      PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
  if (call_finish_visit) {
    TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
  }
  return Status::OK();
}

// Explicit instantiations.
template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
3411
AcceptWithOperandOrder(DfsHloVisitor * visitor,const CompareFunction & operand_order,bool call_finish_visit)3412 Status HloInstruction::AcceptWithOperandOrder(
3413 DfsHloVisitor* visitor, const CompareFunction& operand_order,
3414 bool call_finish_visit) {
3415 VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";
3416 InternalCompareFunction func = [&operand_order](
3417 std::pair<int, const HloInstruction*> a,
3418 std::pair<int, const HloInstruction*> b) {
3419 // Call the client's comparison function on the actual HloInstruction*
3420 // objects (ignoring the internal ids we also have in our stack entries)
3421 return operand_order(a.second, b.second);
3422 };
3423 TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &func,
3424 /*ignore_control_predecessors=*/false));
3425 if (call_finish_visit) {
3426 VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
3427 TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
3428 VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
3429 }
3430 VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
3431 return Status::OK();
3432 }
3433
// Returns the result shape of this instruction.
const Shape& HloInstruction::shape() const { return shape_; }
3435
OperandIndices(const HloInstruction * operand) const3436 absl::InlinedVector<int64, 4> HloInstruction::OperandIndices(
3437 const HloInstruction* operand) const {
3438 absl::InlinedVector<int64, 4> result;
3439 for (int64 i = 0; i < operand_count(); ++i) {
3440 if (this->operand(i) == operand) {
3441 result.push_back(i);
3442 }
3443 }
3444 return result;
3445 }
3446
IsElementwiseBinary() const3447 bool HloInstruction::IsElementwiseBinary() const {
3448 return IsElementwise() && operand_count() == 2;
3449 }
3450
IsElementwise() const3451 bool HloInstruction::IsElementwise() const {
3452 return IsElementwiseImpl(absl::nullopt);
3453 }
3454
IsElementwiseOnOperand(int64 operand_idx) const3455 bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
3456 return IsElementwiseImpl(operand_idx);
3457 }
3458
// A helper class for memoized, recursive computation of HloOpcode::kFusion
// in HloInstruction::OperandElementUse below.
class HloInstruction::FusionReusesParamElements {
 public:
  using UseKind = HloInstruction::UseKind;

  // We could rather iterate backwards through fused_instructions_ here, as it
  // is in reverse postorder, and compute whether each fused instruction reuses
  // the value of this parameter, which would save stack space but not allow us
  // to finish early if we find a reuse.
  //
  // Returns how fusion root `hlo` uses the elements of fusion parameter `i`.
  static UseKind Compute(int64 i, const HloInstruction& hlo) {
    absl::flat_hash_map<const HloInstruction*, UseKind> memoization_cache;
    return ComputeInternal(i, hlo, &memoization_cache);
  }

 private:
  // Recursive worker for Compute(); `cache` memoizes per-instruction results
  // and doubles as the in-progress marker that prevents infinite recursion.
  static UseKind ComputeInternal(
      int64 i, const HloInstruction& hlo,
      absl::flat_hash_map<const HloInstruction*, UseKind>* cache) {
    if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) {
      if (hlo_param->parameter_number() == i) {
        return UseKind::kUse;
      }
    }

    auto p = cache->emplace(&hlo, UseKind::kNoUse);
    auto value_it = p.first;
    const bool key_is_new = p.second;

    if (key_is_new) {
      for (int64 j = 0; j < hlo.operands_.size(); ++j) {
        UseKind old_val = value_it->second;

        // The next operation invalidates iterators.
        UseKind new_val =
            Fold(old_val,
                 FoldUseMandatory(hlo.OperandElementUse(j),
                                  ComputeInternal(i, *hlo.operand(j), cache)));

        // Re-acquire the iterator. We could work harder to do this only if
        // absolutely necessary, but this code is not hot enough to warrant
        // that.
        value_it = cache->find(&hlo);
        value_it->second = new_val;
        // Fold() minimizes the UseKind value. If it is already minimum, we can
        // break the loop early.
        if (new_val == UseKind::kReuse) {
          break;
        }
      }
    }
    return value_it->second;
  }

  // Combines two UseKinds.
  //
  // This is the min operation on the lattice
  //
  //   kReuse < kUse < kNoUse.
  //
  // Two kUses uses which have different permutations count as kReuse.
  static UseKind Fold(UseKind a, UseKind b) {
    // Without loss of generality, let `b` be the operation with the larger use
    // kind.
    if (b.kind < a.kind) {
      std::swap(a, b);
    }
    // If the kinds are different, return the smaller one, namely `a`.
    if (a.kind != b.kind) {
      return a;
    }
    // If the kinds are both kUse, check that they're the same permutation.
    if (a.kind == UseKind::kUse && b.kind == UseKind::kUse &&
        a.permutation_instr != b.permutation_instr) {
      return UseKind::kReuse;
    }
    return a;  // They're the same.
  }

  // Combines two UseKinds differently than Fold().
  //
  // This is the min operation on the lattice
  //
  //   kNoUse < kReuse < kUse.
  //
  // If `a` and `b` are both kUse and one has a non-null permutation
  // instruction, returns kUse with that permutation. OTOH if both have
  // different, non-null permutation instructions, returns kReuse.
  //
  // You can think of this sort of as a conjunction, whereas Fold is sort of a
  // disjunction. FoldUseMandatory() says "no use" if either input isn't used,
  // whereas Fold() would say "use".
  static UseKind FoldUseMandatory(UseKind a, UseKind b) {
    if (a.kind == UseKind::kNoUse || b.kind == UseKind::kNoUse) {
      return UseKind::kNoUse;
    }
    if (a.kind == UseKind::kReuse || b.kind == UseKind::kReuse) {
      return UseKind::kReuse;
    }
    if (a.permutation_instr == b.permutation_instr) {
      return a;  // They're the same.
    }
    if (b.permutation_instr == nullptr) {
      return a;
    }
    if (a.permutation_instr == nullptr) {
      return b;
    }
    return UseKind::kReuse;
  }
};
3570
// Classifies how this instruction consumes the elements of operand
// `operand_num`: a plain use, a permuting use, or a (potential) reuse.
HloInstruction::UseKind HloInstruction::OperandElementUse(
    int64 operand_num) const {
  switch (opcode_) {
    case HloOpcode::kBitcast:
      // A bitcast that only adds or removes degenerate (i.e. size 1) dimensions
      // doesn't permute its elements, so it counts as a plain, non-permuting
      // use.
      return ShapeUtil::DropDegenerateDimensions(shape()) ==
                     ShapeUtil::DropDegenerateDimensions(operand(0)->shape())
                 ? UseKind::kUse
                 : UseKind::Permuting(this);
    case HloOpcode::kConcatenate:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kSlice:
    case HloOpcode::kTranspose:
      return UseKind::Permuting(this);
    case HloOpcode::kPad:
      // Pad reuses the padding value but not the padded array elements.
      return operand_num > 0 ? UseKind::kReuse : UseKind::Permuting(this);
    case HloOpcode::kReduce:
      // Reduce reuses the init values but not the operand array elements.
      return operand_num >= Cast<HloReduceInstruction>(this)->input_count()
                 ? UseKind::kReuse
                 : UseKind::Permuting(this);
    case HloOpcode::kFusion:
      // Uses the memoizing, recursive computation defined above.
      return FusionReusesParamElements::Compute(operand_num,
                                                *fused_expression_root());
    case HloOpcode::kDot:
      // Matrix-vector dots do not reuse the matrix operand.
      if (shape().dimensions_size() <= 1) {
        if ((operand_num == 0 && operand(1)->shape().rank() <= 1) ||
            (operand_num == 1 && operand(0)->shape().rank() <= 1)) {
          return UseKind::kUse;
        }
      }
      return UseKind::kReuse;
    case HloOpcode::kDynamicUpdateSlice:
      // Dynamic-update-slice reuses only start_indices.
      if (operand_num == 0 || operand_num == 1) {
        return UseKind::kUse;
      }
      return UseKind::kReuse;
    case HloOpcode::kGather:
      // Gather reads its indices in a linear fashion, and it permutes the
      // vector it's gathering from.
      return operand_num == 0 ? UseKind::kUse : UseKind::Permuting(this);
    default:
      // Elementwise ops touch each operand element exactly once; anything
      // else is conservatively assumed to reuse.
      return IsElementwise() ? UseKind::kUse : UseKind::kReuse;
  }
}
3623
3624 std::tuple<bool, std::vector<int64>, std::vector<int64>>
ReshapeMerelyInsertsOrDeletes1SizedDimensions() const3625 HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
3626 if (HloOpcode::kReshape != opcode_) {
3627 return std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
3628 }
3629 return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
3630 shape_);
3631 }
3632
ToString(HloInstruction::FusionKind kind)3633 string ToString(HloInstruction::FusionKind kind) {
3634 switch (kind) {
3635 case HloInstruction::FusionKind::kLoop:
3636 return "kLoop";
3637 case HloInstruction::FusionKind::kInput:
3638 return "kInput";
3639 case HloInstruction::FusionKind::kOutput:
3640 return "kOutput";
3641 case HloInstruction::FusionKind::kCustom:
3642 return "kCustom";
3643 }
3644 }
3645
StringToFusionKind(const string & kind_name)3646 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
3647 const string& kind_name) {
3648 if (kind_name == "kLoop") {
3649 return HloInstruction::FusionKind::kLoop;
3650 }
3651 if (kind_name == "kInput") {
3652 return HloInstruction::FusionKind::kInput;
3653 }
3654 if (kind_name == "kOutput") {
3655 return HloInstruction::FusionKind::kOutput;
3656 }
3657 if (kind_name == "kCustom") {
3658 return HloInstruction::FusionKind::kCustom;
3659 }
3660 return InvalidArgument("Unknown fusion kind: %s", kind_name);
3661 }
3662
FrontendAttributesToString(const FrontendAttributes & frontend_attributes)3663 string FrontendAttributesToString(
3664 const FrontendAttributes& frontend_attributes) {
3665 std::vector<std::pair<string, string>> sorted_attributes(
3666 frontend_attributes.map().begin(), frontend_attributes.map().end());
3667 absl::c_sort(sorted_attributes);
3668 // Frontend attribute is a comma-separated list of attribute="value" pairs,
3669 // e.g., frontend_attributes={name="value_a",type="int32"}.
3670 const auto formatter = [](string* out,
3671 const std::pair<string, string>& item) {
3672 absl::StrAppend(out, item.first, "=\"", item.second, "\"");
3673 };
3674 return absl::StrFormat("{%s}",
3675 absl::StrJoin(sorted_attributes, ",", formatter));
3676 }
3677
PaddingConfigToString(const PaddingConfig & padding)3678 string PaddingConfigToString(const PaddingConfig& padding) {
3679 bool has_interior_padding =
3680 absl::c_any_of(padding.dimensions(),
3681 [](const PaddingConfig::PaddingConfigDimension& dim) {
3682 return dim.interior_padding() != 0;
3683 });
3684 return StrJoin(
3685 padding.dimensions(), "x",
3686 [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
3687 StrAppend(
3688 out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
3689 has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
3690 });
3691 }
3692
RandomDistributionToString(const RandomDistribution & distribution)3693 string RandomDistributionToString(const RandomDistribution& distribution) {
3694 return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
3695 }
RandomAlgorithmToString(const RandomAlgorithm & algorithm)3696 string RandomAlgorithmToString(const RandomAlgorithm& algorithm) {
3697 return absl::AsciiStrToLower(RandomAlgorithm_Name(algorithm));
3698 }
3699
PrecisionToString(const PrecisionConfig::Precision & precision)3700 string PrecisionToString(const PrecisionConfig::Precision& precision) {
3701 return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision));
3702 }
3703
ConvolutionDimensionNumbersToString(const ConvolutionDimensionNumbers & dnums)3704 string ConvolutionDimensionNumbersToString(
3705 const ConvolutionDimensionNumbers& dnums) {
3706 // lhs_dims[i] is the symbol of the logical dimension i for the lhs
3707 // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
3708 std::vector<string> lhs_dims(2 + dnums.input_spatial_dimensions().size());
3709 lhs_dims[dnums.input_batch_dimension()] = 'b';
3710 lhs_dims[dnums.input_feature_dimension()] = 'f';
3711 for (int64 i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
3712 lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i);
3713 }
3714
3715 std::vector<string> rhs_dims(2 + dnums.kernel_spatial_dimensions().size());
3716 rhs_dims[dnums.kernel_input_feature_dimension()] = "i";
3717 rhs_dims[dnums.kernel_output_feature_dimension()] = "o";
3718 for (int64 i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) {
3719 rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i);
3720 }
3721
3722 std::vector<string> output_dims(2 + dnums.output_spatial_dimensions().size());
3723 output_dims[dnums.output_batch_dimension()] = 'b';
3724 output_dims[dnums.output_feature_dimension()] = 'f';
3725 for (int64 i = 0; i < dnums.output_spatial_dimensions().size(); ++i) {
3726 output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
3727 }
3728
3729 return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->",
3730 StrJoin(output_dims, ""));
3731 }
3732
ReplicaGroupsToString(const std::vector<ReplicaGroup> & replica_groups)3733 string ReplicaGroupsToString(const std::vector<ReplicaGroup>& replica_groups) {
3734 std::vector<string> replica_group_str;
3735 replica_group_str.reserve(replica_groups.size());
3736 for (const ReplicaGroup& group : replica_groups) {
3737 replica_group_str.push_back(
3738 StrCat("{", StrJoin(group.replica_ids(), ","), "}"));
3739 }
3740 return StrCat("{", StrJoin(replica_group_str, ","), "}");
3741 }
3742
StringToRandomAlgorithm(const string & name)3743 StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const string& name) {
3744 static std::unordered_map<string, RandomAlgorithm>* map = [] {
3745 static auto* map = new std::unordered_map<string, RandomAlgorithm>;
3746 for (int i = 0; i < RandomAlgorithm_ARRAYSIZE; i++) {
3747 if (RandomAlgorithm_IsValid(i)) {
3748 auto value = static_cast<RandomAlgorithm>(i);
3749 (*map)[RandomAlgorithmToString(value)] = value;
3750 }
3751 }
3752 return map;
3753 }();
3754 auto found = map->find(absl::AsciiStrToLower(name));
3755 if (found == map->end()) {
3756 return InvalidArgument("Unknown algorithm");
3757 }
3758 return found->second;
3759 }
3760
StringToRandomDistribution(const string & name)3761 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
3762 static std::unordered_map<string, RandomDistribution>* map = [] {
3763 static auto* map = new std::unordered_map<string, RandomDistribution>;
3764 for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
3765 if (RandomDistribution_IsValid(i)) {
3766 auto value = static_cast<RandomDistribution>(i);
3767 (*map)[RandomDistributionToString(value)] = value;
3768 }
3769 }
3770 return map;
3771 }();
3772 auto found = map->find(absl::AsciiStrToLower(name));
3773 if (found == map->end()) {
3774 return InvalidArgument("Unknown distribution");
3775 }
3776 return found->second;
3777 }
3778
StringToPrecision(const string & name)3779 StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) {
3780 static std::unordered_map<string, PrecisionConfig::Precision>* map = [] {
3781 static auto* map =
3782 new std::unordered_map<string, PrecisionConfig::Precision>;
3783 for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) {
3784 if (PrecisionConfig::Precision_IsValid(i)) {
3785 auto value = static_cast<PrecisionConfig::Precision>(i);
3786 (*map)[PrecisionToString(value)] = value;
3787 }
3788 }
3789 return map;
3790 }();
3791 auto found = map->find(absl::AsciiStrToLower(name));
3792 if (found == map->end()) {
3793 return InvalidArgument("Unknown distribution");
3794 }
3795 return found->second;
3796 }
3797
operator <<(std::ostream & os,HloInstruction::FusionKind kind)3798 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
3799 return os << ToString(kind);
3800 }
3801
// Strict-weak ordering on HloInstruction pointers that compares stable
// unique ids (module id, then instruction id) instead of pointer values.
// nullptr sorts after every non-null instruction.
bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
                                  const HloInstruction* const& rhs) const {
  if (rhs == nullptr) {
    // Nothing compares less than nullptr.
    return false;
  }
  if (lhs == nullptr) {
    return true;
  }
  auto lhs_module = lhs->GetModule();
  auto rhs_module = rhs->GetModule();
  // Either both instructions are attached to a module or neither is;
  // comparing an attached instruction against a detached one is not
  // meaningful.
  CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
        (lhs_module != nullptr && rhs_module != nullptr));
  if (lhs_module != nullptr &&
      lhs_module->unique_id() != rhs_module->unique_id()) {
    return lhs_module->unique_id() < rhs_module->unique_id();
  }
  // Same module (or both detached): order by the instruction's unique id.
  return lhs->unique_id() < rhs->unique_id();
}
3821
// Parses the raw backend_config_ JSON string into `proto`. `proto` is
// cleared first; an empty config string is treated as the empty proto
// rather than as a JSON parse error.
Status HloInstruction::GetBackendConfigInternal(
    tensorflow::protobuf::Message* proto) const {
  proto->Clear();

  // Empty string does not parse as valid JSON, but it's a valid backend config,
  // corresponding to the empty proto.
  if (backend_config_.empty()) {
    return Status::OK();
  }
  return tensorflow::HumanReadableJsonToProto(backend_config_, proto);
}
3833
// Serializes `proto` to its human-readable JSON form and stores the result
// as this instruction's backend config string. Returns the serialization
// error, if any.
Status HloInstruction::set_backend_config(
    const tensorflow::protobuf::Message& proto) {
  TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto));
  return Status::OK();
}
3839
// Converts a backend-config proto to the canonical JSON string stored in
// backend_config_.
/* static */ StatusOr<string> HloInstruction::BackendConfigToRawString(
    const tensorflow::protobuf::Message& proto) {
  string ret;
  // Pass ignore_accuracy_loss = true because estimated_cycles field can be
  // INT64_MAX. If ignore_accuracy_loss = false and estimated_cycles =
  // INT64_MAX, JsonFormat will return an error status, although there is no
  // accuracy loss for int64.
  TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(
      proto, &ret, /*ignore_accuracy_loss=*/true));
  return ret;
}
3851
// Returns the precision config of ops that carry one (convolution, dot, or
// custom-call). LOG(FATAL)s for any other opcode.
const PrecisionConfig& HloInstruction::precision_config() const {
  if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->precision_config();
  }
  if (auto* dot = DynCast<HloDotInstruction>(this)) {
    return dot->precision_config();
  }

  if (auto* custom_call = DynCast<HloCustomCallInstruction>(this)) {
    return custom_call->precision_config();
  }
  LOG(FATAL) << "Unimplemented method.";
}
3865
mutable_precision_config()3866 PrecisionConfig* HloInstruction::mutable_precision_config() {
3867 if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
3868 return convolution->mutable_precision_config();
3869 }
3870 if (auto* dot = DynCast<HloDotInstruction>(this)) {
3871 return dot->mutable_precision_config();
3872 }
3873 LOG(FATAL) << "Unimplemented method.";
3874 }
3875
GetModule() const3876 HloModule* HloInstruction::GetModule() const {
3877 if (parent_) {
3878 return parent_->parent();
3879 }
3880 return nullptr;
3881 }
3882
UniquifyName(NameUniquer * name_uniquer)3883 void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
3884 string parent_str = parent() == nullptr ? "noparent" : parent()->name();
3885 name_ = name_uniquer->GetUniqueName(name_);
3886 }
3887
// Stores the given outer-dimension partition sizes on this instruction
// (copied into outer_dimension_partitions_).
void HloInstruction::set_outer_dimension_partitions(
    const std::vector<int64>& outer_dimension_partitions) {
  outer_dimension_partitions_ = outer_dimension_partitions;
}
3892
// TODO(b/80131774): Remove these temporary methods after transition.
//
// The accessors below forward to the concrete HloInstruction subclass.
// Cast<> requires the instruction to actually be of that subclass, while
// the DynCast<>-based helpers tolerate any opcode.

int64 HloInstruction::feature_index() const {
  return Cast<HloBatchNormInstruction>(this)->feature_index();
}

float HloInstruction::epsilon() const {
  return Cast<HloBatchNormInstruction>(this)->epsilon();
}

FftType HloInstruction::fft_type() const {
  return Cast<HloFftInstruction>(this)->fft_type();
}

const std::vector<int64>& HloInstruction::fft_length() const {
  return Cast<HloFftInstruction>(this)->fft_length();
}

int64 HloInstruction::concatenate_dimension() const {
  return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
}

// Valid for both set-dimension-size and get-dimension-size instructions.
int64 HloInstruction::dimension() const {
  if (auto set_size = DynCast<HloSetDimensionSizeInstruction>(this)) {
    return set_size->dimension();
  }
  return Cast<HloGetDimensionSizeInstruction>(this)->dimension();
}

int64 HloInstruction::inferred_dimension() const {
  return Cast<HloReshapeInstruction>(this)->inferred_dimension();
}

// True iff this is a transpose instruction whose subclass reports it as a
// rank-2 transpose; false (not fatal) for any other opcode.
bool HloInstruction::IsRank2Transpose() const {
  auto transpose = DynCast<HloTransposeInstruction>(this);
  return transpose != nullptr && transpose->IsRank2Transpose();
}
3929
// Per-dimension slice start/limit/stride accessors. Valid only on slice
// instructions; Cast<> enforces that.

int64 HloInstruction::slice_starts(int64 dimension) const {
  return Cast<HloSliceInstruction>(this)->slice_starts(dimension);
}

const std::vector<int64>& HloInstruction::slice_starts() const {
  return Cast<HloSliceInstruction>(this)->slice_starts();
}

std::vector<int64>* HloInstruction::mutable_slice_starts() {
  return Cast<HloSliceInstruction>(this)->mutable_slice_starts();
}

int64 HloInstruction::slice_limits(int64 dimension) const {
  return Cast<HloSliceInstruction>(this)->slice_limits(dimension);
}

const std::vector<int64>& HloInstruction::slice_limits() const {
  return Cast<HloSliceInstruction>(this)->slice_limits();
}

std::vector<int64>* HloInstruction::mutable_slice_limits() {
  return Cast<HloSliceInstruction>(this)->mutable_slice_limits();
}

int64 HloInstruction::slice_strides(int64 dimension) const {
  return Cast<HloSliceInstruction>(this)->slice_strides(dimension);
}

const std::vector<int64>& HloInstruction::slice_strides() const {
  return Cast<HloSliceInstruction>(this)->slice_strides();
}

std::vector<int64>* HloInstruction::mutable_slice_strides() {
  return Cast<HloSliceInstruction>(this)->mutable_slice_strides();
}
3965
// Returns the literal of a constant instruction; fatal on other opcodes.
const Literal& HloInstruction::literal() const {
  return Cast<HloConstantInstruction>(this)->literal();
}

// True iff this instruction is a kConstant; safe on any opcode.
bool HloInstruction::IsConstant() const {
  return DynCast<HloConstantInstruction>(this) != nullptr;
}

// Delegates to HloConstantInstruction::RelayoutConstant.
void HloInstruction::RelayoutConstant(const Layout& new_layout,
                                      const ShapeIndex& shape_index) {
  Cast<HloConstantInstruction>(this)->RelayoutConstant(new_layout, shape_index);
}

// Delegates to HloTraceInstruction::TracingTag.
string HloInstruction::TracingTag() const {
  return Cast<HloTraceInstruction>(this)->TracingTag();
}

// Delegates to HloFusionInstruction::AddFusionOperand.
HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
  return Cast<HloFusionInstruction>(this)->AddFusionOperand(new_operand);
}
3986
// Delegates to HloFusionInstruction::MergeFusionInstruction.
void HloInstruction::MergeFusionInstruction(
    HloInstruction* instruction_to_merge) {
  return Cast<HloFusionInstruction>(this)->MergeFusionInstruction(
      Cast<HloFusionInstruction>(instruction_to_merge));
}

// Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
void HloInstruction::MergeFusionInstructionIntoMultiOutput(
    HloInstruction* instruction_to_merge) {
  return Cast<HloFusionInstruction>(this)
      ->MergeFusionInstructionIntoMultiOutput(
          Cast<HloFusionInstruction>(instruction_to_merge));
}

// Delegates to HloFusionInstruction::FuseInstruction.
HloInstruction* HloInstruction::FuseInstruction(
    HloInstruction* instruction_to_fuse) {
  return Cast<HloFusionInstruction>(this)->FuseInstruction(instruction_to_fuse);
}

// Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput.
HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput(
    HloInstruction* instruction_to_fuse) {
  return Cast<HloFusionInstruction>(this)->FuseInstructionIntoMultiOutput(
      instruction_to_fuse);
}

// Delegates to HloFusionInstruction::fused_instructions_computation.
HloComputation* HloInstruction::fused_instructions_computation() const {
  return Cast<HloFusionInstruction>(this)->fused_instructions_computation();
}

// Delegates to HloFusionInstruction::fused_expression_root.
HloInstruction* HloInstruction::fused_expression_root() const {
  return Cast<HloFusionInstruction>(this)->fused_expression_root();
}

// Delegates to HloFusionInstruction::fused_instructions (const overload).
const tensorflow::gtl::iterator_range<UnwrappingIterator<
    std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
HloInstruction::fused_instructions() const {
  return Cast<HloFusionInstruction>(this)->fused_instructions();
}

// Delegates to HloFusionInstruction::fused_instructions (mutable overload).
const tensorflow::gtl::iterator_range<
    UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
HloInstruction::fused_instructions() {
  return Cast<HloFusionInstruction>(this)->fused_instructions();
}

// Delegates to HloFusionInstruction::fused_instruction_count.
int64 HloInstruction::fused_instruction_count() const {
  return Cast<HloFusionInstruction>(this)->fused_instruction_count();
}

// Delegates to HloFusionInstruction::fused_parameter.
HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
  return Cast<HloFusionInstruction>(this)->fused_parameter(parameter_number);
}

// Delegates to HloFusionInstruction::fused_parameters.
const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
  return Cast<HloFusionInstruction>(this)->fused_parameters();
}

// Safe on any opcode: false when this is not a fusion instruction.
// NOTE(review): the top-level const on the return type is meaningless for a
// by-value bool; kept as-is because it must match the header declaration.
const bool HloInstruction::IsMultiOutputFusion() const {
  const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(this);
  return fusion != nullptr && fusion->IsMultiOutputFusion();
}

// Delegates to HloFusionInstruction::fusion_kind.
HloInstruction::FusionKind HloInstruction::fusion_kind() const {
  return Cast<HloFusionInstruction>(this)->fusion_kind();
}

// Delegates to HloFusionInstruction::set_fusion_kind.
void HloInstruction::set_fusion_kind(FusionKind kind) {
  return Cast<HloFusionInstruction>(this)->set_fusion_kind(kind);
}
4057
// Delegates to HloRngInstruction::random_distribution.
RandomDistribution HloInstruction::random_distribution() const {
  return Cast<HloRngInstruction>(this)->random_distribution();
}

// Delegates to HloParameterInstruction::parameter_number.
int64 HloInstruction::parameter_number() const {
  return Cast<HloParameterInstruction>(this)->parameter_number();
}

// Delegates to HloParameterInstruction::
// set_parameter_replicated_at_leaf_buffers (span overload).
void HloInstruction::set_parameter_replicated_at_leaf_buffers(
    absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
  return Cast<HloParameterInstruction>(this)
      ->set_parameter_replicated_at_leaf_buffers(
          parameter_replicated_at_leaf_buffers);
}

// Delegates to HloParameterInstruction::
// set_parameter_replicated_at_leaf_buffers (vector overload).
void HloInstruction::set_parameter_replicated_at_leaf_buffers(
    const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
  return Cast<HloParameterInstruction>(this)
      ->set_parameter_replicated_at_leaf_buffers(
          parameter_replicated_at_leaf_buffers);
}

// Delegates to HloParameterInstruction::
// parameter_replicated_at_leaf_buffers.
const absl::optional<std::vector<bool>>&
HloInstruction::parameter_replicated_at_leaf_buffers() const {
  return Cast<HloParameterInstruction>(this)
      ->parameter_replicated_at_leaf_buffers();
}

// Delegates to HloGetTupleElementInstruction::tuple_index.
int64 HloInstruction::tuple_index() const {
  return Cast<HloGetTupleElementInstruction>(this)->tuple_index();
}

// Delegates to HloGetTupleElementInstruction::set_tuple_index.
void HloInstruction::set_tuple_index(int64 new_tuple_index) {
  return Cast<HloGetTupleElementInstruction>(this)->set_tuple_index(
      new_tuple_index);
}

// Delegates to HloReducePrecisionInstruction::exponent_bits.
int32 HloInstruction::exponent_bits() const {
  return Cast<HloReducePrecisionInstruction>(this)->exponent_bits();
}

// Delegates to HloReducePrecisionInstruction::mantissa_bits.
int32 HloInstruction::mantissa_bits() const {
  return Cast<HloReducePrecisionInstruction>(this)->mantissa_bits();
}
4102
// Delegates to HloInfeedInstruction::infeed_config.
string HloInstruction::infeed_config() const {
  return Cast<HloInfeedInstruction>(this)->infeed_config();
}

// Delegates to HloInfeedInstruction::set_infeed_config.
void HloInstruction::set_infeed_config(const string& config) {
  return Cast<HloInfeedInstruction>(this)->set_infeed_config(config);
}

// Delegates to HloOutfeedInstruction::outfeed_shape.
const Shape& HloInstruction::outfeed_shape() const {
  return Cast<HloOutfeedInstruction>(this)->outfeed_shape();
}

// Delegates to HloOutfeedInstruction::mutable_outfeed_shape.
Shape* HloInstruction::mutable_outfeed_shape() {
  return Cast<HloOutfeedInstruction>(this)->mutable_outfeed_shape();
}

// Delegates to HloOutfeedInstruction::outfeed_config.
const string& HloInstruction::outfeed_config() const {
  return Cast<HloOutfeedInstruction>(this)->outfeed_config();
}

// Delegates to HloOutfeedInstruction::set_outfeed_config.
void HloInstruction::set_outfeed_config(const string& config) {
  return Cast<HloOutfeedInstruction>(this)->set_outfeed_config(config);
}

// Delegates to HloCollectiveInstruction::replica_groups.
const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
  return Cast<HloCollectiveInstruction>(this)->replica_groups();
}

// Delegates to HloCollectivePermuteInstruction::source_target_pairs.
const std::vector<std::pair<int64, int64>>&
HloInstruction::source_target_pairs() const {
  return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
}

// Delegates to HloChannelInstruction::channel_id.
absl::optional<int64> HloInstruction::channel_id() const {
  return Cast<HloChannelInstruction>(this)->channel_id();
}

// Delegates to HloChannelInstruction::set_channel_id.
void HloInstruction::set_channel_id(const absl::optional<int64>& channel_id) {
  return Cast<HloChannelInstruction>(this)->set_channel_id(channel_id);
}
4143
// Returns the convolution dimension numbers carried by either a convolution
// or a custom-call instruction; fatal for any other opcode.
const ConvolutionDimensionNumbers&
HloInstruction::convolution_dimension_numbers() const {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->convolution_dimension_numbers();
  }
  if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
    return custom_call->convolution_dimension_numbers();
  }
  LOG(FATAL) << "Unimplemented method.";
}

// Setter counterpart of convolution_dimension_numbers(); same opcode rules.
void HloInstruction::set_convolution_dimension_numbers(
    const ConvolutionDimensionNumbers& dnums) {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    convolution->set_convolution_dimension_numbers(dnums);
  } else if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
    custom_call->set_convolution_dimension_numbers(dnums);
  } else {
    LOG(FATAL) << "Unimplemented method.";
  }
}

// Feature-group count of a convolution or (convolution-like) custom-call.
int64 HloInstruction::feature_group_count() const {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->feature_group_count();
  }
  return Cast<HloCustomCallInstruction>(this)->feature_group_count();
}

// Setter counterpart of feature_group_count().
void HloInstruction::set_feature_group_count(int64 feature_group_count) {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->set_feature_group_count(feature_group_count);
  }
  Cast<HloCustomCallInstruction>(this)->set_feature_group_count(
      feature_group_count);
}

// Batch-group count of a convolution or (convolution-like) custom-call.
int64 HloInstruction::batch_group_count() const {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->batch_group_count();
  }
  return Cast<HloCustomCallInstruction>(this)->batch_group_count();
}

// Setter counterpart of batch_group_count().
void HloInstruction::set_batch_group_count(int64 batch_group_count) {
  if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
    return convolution->set_batch_group_count(batch_group_count);
  }
  Cast<HloCustomCallInstruction>(this)->set_batch_group_count(
      batch_group_count);
}
4195
// Delegates to HloSelectAndScatterInstruction::select.
HloComputation* HloInstruction::select() const {
  return Cast<HloSelectAndScatterInstruction>(this)->select();
}

// Delegates to HloSelectAndScatterInstruction::scatter.
HloComputation* HloInstruction::scatter() const {
  return Cast<HloSelectAndScatterInstruction>(this)->scatter();
}

// Delegates to HloSelectAndScatterInstruction::set_select.
void HloInstruction::set_select(HloComputation* computation) {
  return Cast<HloSelectAndScatterInstruction>(this)->set_select(computation);
}

// Delegates to HloSelectAndScatterInstruction::set_scatter.
void HloInstruction::set_scatter(HloComputation* computation) {
  return Cast<HloSelectAndScatterInstruction>(this)->set_scatter(computation);
}

// Delegates to HloCustomCallInstruction::custom_call_target.
const string& HloInstruction::custom_call_target() const {
  return Cast<HloCustomCallInstruction>(this)->custom_call_target();
}

// Delegates to HloPadInstruction::padding_config.
const PaddingConfig& HloInstruction::padding_config() const {
  return Cast<HloPadInstruction>(this)->padding_config();
}

// Delegates to HloCustomCallInstruction::padding_type (note: custom-call,
// not pad, carries this field).
PaddingType HloInstruction::padding_type() const {
  return Cast<HloCustomCallInstruction>(this)->padding_type();
}

// Delegates to HloPadInstruction::mutable_padding_config.
PaddingConfig* HloInstruction::mutable_padding_config() {
  return Cast<HloPadInstruction>(this)->mutable_padding_config();
}

// Delegates to HloDynamicSliceInstruction::slice_sizes.
int64 HloInstruction::slice_sizes(int64 dimension) const {
  return Cast<HloDynamicSliceInstruction>(this)->slice_sizes(dimension);
}

// Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes.
const std::vector<int64>& HloInstruction::dynamic_slice_sizes() const {
  return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes();
}
4235
// Delegates to HloGatherInstruction::gather_dimension_numbers.
const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
  return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
}

// Delegates to HloGatherInstruction::gather_slice_sizes.
absl::Span<const int64> HloInstruction::gather_slice_sizes() const {
  return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
}

// Delegates to HloScatterInstruction::scatter_dimension_numbers.
const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
    const {
  return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
}

// Delegates to HloDotInstruction::dot_dimension_numbers.
const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
  return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
}

// Delegates to HloDomainInstruction::operand_side_metadata.
const DomainMetadata& HloInstruction::operand_side_metadata() const {
  return Cast<HloDomainInstruction>(this)->operand_side_metadata();
}

// Delegates to HloDomainInstruction::user_side_metadata.
const DomainMetadata& HloInstruction::user_side_metadata() const {
  return Cast<HloDomainInstruction>(this)->user_side_metadata();
}

// Delegates to HloCopyStartInstruction::is_cross_program_prefetch.
bool HloInstruction::is_cross_program_prefetch() const {
  return Cast<HloCopyStartInstruction>(this)->is_cross_program_prefetch();
}

// Delegates to HloCompareInstruction::direction.
ComparisonDirection HloInstruction::comparison_direction() const {
  return Cast<HloCompareInstruction>(this)->direction();
}

// Delegates to HloTriangularSolveInstruction::triangular_solve_options.
const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
  return Cast<HloTriangularSolveInstruction>(this)->triangular_solve_options();
}

// Delegates to HloCholeskyInstruction::cholesky_options.
const CholeskyOptions& HloInstruction::cholesky_options() const {
  return Cast<HloCholeskyInstruction>(this)->cholesky_options();
}
4276
4277 } // namespace xla
4278