/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_instruction.h"

#include <algorithm>
#include <ostream>
#include <set>
#include <unordered_set>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
#include "absl/strings/escaping.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using absl::CEscape;
using absl::StrAppend;
using absl::StrCat;
using absl::StrJoin;

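// Deserializes an HloInstructionProto into an HloInstruction. Operand and
// called-computation ids in the proto are resolved through instruction_map
// and computation_map, which must already contain every referenced
// instruction and computation.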
/* static */
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
    const HloInstructionProto& proto,
    const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
    const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
  TF_RET_CHECK(!proto.opcode().empty());
  HloOpcode opcode;
  auto opcode_or = StringToHloOpcode(proto.opcode());
  absl::optional<ComparisonDirection> comparison_direction;
  if (opcode_or.ok()) {
    opcode = opcode_or.ConsumeValueOrDie();
  } else {
    // Unknown opcode. Try auto-upgrading deprecated "less-than",
    // "greater-than", etc opcodes, which are now rolled into the kCompare
    // opcode.
    if (proto.opcode() == "equal-to") {
      comparison_direction = ComparisonDirection::kEq;
    } else if (proto.opcode() == "not-equal-to") {
      comparison_direction = ComparisonDirection::kNe;
    } else if (proto.opcode() == "greater-than-or-equal-to") {
      comparison_direction = ComparisonDirection::kGe;
    } else if (proto.opcode() == "greater-than") {
      comparison_direction = ComparisonDirection::kGt;
    } else if (proto.opcode() == "less-than-or-equal-to") {
      comparison_direction = ComparisonDirection::kLe;
    } else if (proto.opcode() == "less-than") {
      comparison_direction = ComparisonDirection::kLt;
    }
    if (comparison_direction) {
      opcode = HloOpcode::kCompare;
    } else {
      return InvalidArgument("Unknown opcode: %s", proto.opcode());
    }
  }

  TF_RET_CHECK(proto.has_shape());

  std::unique_ptr<HloInstruction> instruction;
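  // Helper lambdas that resolve operand and called-computation ids from the
  // proto into pointers through the maps supplied by the caller.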
  const auto operands = [&instruction_map, &proto](int index) {
    return instruction_map.at(proto.operand_ids(index));
  };
  const auto all_operands = [&instruction_map, &proto]() {
    std::vector<HloInstruction*> result(proto.operand_ids_size());
    std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
                   result.begin(), [&instruction_map](int64 operand_id) {
                     return instruction_map.at(operand_id);
                   });
    return result;
  };
  const auto computations = [&computation_map, &proto](int index) {
    return computation_map.at(proto.called_computation_ids(index));
  };
  const auto all_computations = [&computation_map, &proto]() {
    std::vector<HloComputation*> result(proto.called_computation_ids_size());
    std::transform(proto.called_computation_ids().begin(),
                   proto.called_computation_ids().end(), result.begin(),
                   [&computation_map](int64 computation_id) {
                     return computation_map.at(computation_id);
                   });
    return result;
  };

  TF_RET_CHECK(
      absl::c_all_of(proto.operand_ids(),
                     [&](int64 id) { return instruction_map.contains(id); }))
      << proto.name() << " instruction contains invalid operand id(s)";

  TF_RET_CHECK(
      absl::c_all_of(proto.called_computation_ids(),
                     [&](int64 id) { return computation_map.contains(id); }))
      << proto.name() << " instruction references invalid computation id(s)";

  Shape shape(proto.shape());
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));

  absl::optional<int> arity = HloOpcodeArity(opcode);
  if (arity) {
    TF_RET_CHECK(proto.operand_ids_size() == *arity)
        << proto.opcode() << " instruction should have " << *arity
        << " operands but sees " << proto.operand_ids_size();
  }

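  // Opcodes with dedicated subclasses are handled by explicit cases below;
  // everything else falls through to the default case, which builds a plain
  // HloInstruction and wires up operands and called computations directly
  // from the proto.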
  switch (opcode) {
    // Ops migrated to subclasses.
    case HloOpcode::kBatchNormTraining:
      instruction =
          CreateBatchNormTraining(shape, operands(0), operands(1), operands(2),
                                  proto.epsilon(), proto.feature_index());
      break;
    case HloOpcode::kBatchNormInference:
      instruction = CreateBatchNormInference(
          shape, operands(0), operands(1), operands(2), operands(3),
          operands(4), proto.epsilon(), proto.feature_index());
      break;
    case HloOpcode::kBatchNormGrad:
      instruction = CreateBatchNormGrad(shape, operands(0), operands(1),
                                        operands(2), operands(3), operands(4),
                                        proto.epsilon(), proto.feature_index());
      break;
    case HloOpcode::kFft: {
      std::vector<int64> fft_length(proto.fft_length().begin(),
                                    proto.fft_length().end());
      instruction = CreateFft(shape, operands(0), proto.fft_type(),
                              absl::Span<const int64>(fft_length));
      break;
    }
    case HloOpcode::kCompare: {
      // If the opcode was auto-upgraded from a deprecated comparison opcode,
      // the direction is already known, so skip reading it from the proto.
      if (!comparison_direction) {
        TF_ASSIGN_OR_RETURN(
            comparison_direction,
            StringToComparisonDirection(proto.comparison_direction()));
      }
      instruction =
          CreateCompare(shape, operands(0), operands(1), *comparison_direction);
      break;
    }
    case HloOpcode::kTriangularSolve: {
      instruction = CreateTriangularSolve(shape, operands(0), operands(1),
                                          proto.triangular_solve_options());
      break;
    }
    case HloOpcode::kCholesky: {
      instruction =
          CreateCholesky(shape, operands(0), proto.cholesky_options());
      break;
    }
    case HloOpcode::kSend:
      instruction = CreateSend(operands(0), operands(1), proto.channel_id(),
                               proto.is_host_transfer());
      break;
    case HloOpcode::kSendDone:
      instruction = CreateSendDone(operands(0), proto.is_host_transfer());
      break;
    case HloOpcode::kRecv:
      instruction = CreateRecv(shape.tuple_shapes(0), operands(0),
                               proto.channel_id(), proto.is_host_transfer());
      break;
    case HloOpcode::kRecvDone:
      instruction = CreateRecvDone(operands(0), proto.is_host_transfer());
      break;
    case HloOpcode::kReverse:
      instruction = CreateReverse(shape, operands(0),
                                  std::vector<int64>(proto.dimensions().begin(),
                                                     proto.dimensions().end()));
      break;
    case HloOpcode::kConcatenate:
      TF_RET_CHECK(proto.dimensions_size() == 1)
          << "Concatenate instruction should have 1 dimension but sees "
          << proto.dimensions_size();
      instruction =
          CreateConcatenate(shape, all_operands(), proto.dimensions(0));
      break;
    case HloOpcode::kConditional: {
      TF_RET_CHECK(proto.called_computation_ids_size() > 0)
          << "conditional should have at least 1 called computation";
      if (operands(0)->shape().element_type() == PRED) {
        TF_RET_CHECK(proto.called_computation_ids_size() == 2)
            << "conditional should have exactly 2 called computations but got "
            << proto.called_computation_ids_size();
      }
      TF_RET_CHECK(proto.operand_ids_size() ==
                   proto.called_computation_ids_size() + 1)
          << "conditional should have one branch_index operand plus one "
             "operand per called computation but got "
          << proto.operand_ids_size() << " operands for "
          << proto.called_computation_ids_size() << " branch computations";
      auto cond_operands = all_operands();
      instruction =
          CreateConditional(shape, cond_operands[0], all_computations(),
                            absl::MakeSpan(cond_operands).subspan(1));
      break;
    }
    case HloOpcode::kReduce:
      TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
          << "Reduce instruction should have an even number of operands but "
             "sees "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Reduce instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
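      // The operand list packs all reduce inputs first, followed by the
      // matching init values, so it is split in half below.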
      {
        const auto reduce_operands = all_operands();
        auto inputs = absl::MakeSpan(reduce_operands)
                          .subspan(0, reduce_operands.size() / 2);
        auto init_values =
            absl::MakeSpan(reduce_operands)
                .subspan(reduce_operands.size() / 2, reduce_operands.size());
        instruction =
            CreateReduce(shape, inputs, init_values,
                         std::vector<int64>(proto.dimensions().begin(),
                                            proto.dimensions().end()),
                         computations(0));
      }
      break;
    case HloOpcode::kSort: {
      TF_RET_CHECK(proto.operand_ids_size() >= 1)
          << "Sort instruction should have at least 1 operand but has "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.dimensions().size() == 1)
          << "Sort instruction should have 1 dimension";
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Sort instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      auto sort_operands = all_operands();
      instruction = CreateSort(shape, proto.dimensions(0), sort_operands,
                               computations(0), proto.is_stable());
      break;
    }
    case HloOpcode::kTranspose:
      instruction =
          CreateTranspose(shape, operands(0),
                          std::vector<int64>(proto.dimensions().begin(),
                                             proto.dimensions().end()));
      break;
    case HloOpcode::kBroadcast:
      instruction =
          CreateBroadcast(shape, operands(0),
                          std::vector<int64>(proto.dimensions().begin(),
                                             proto.dimensions().end()));
      break;
    case HloOpcode::kMap:
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Map instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      instruction = CreateMap(shape, all_operands(), computations(0));
      break;
    case HloOpcode::kSlice: {
      std::vector<int64> slice_starts, slice_limits, slice_strides;
      for (const HloInstructionProto::SliceDimensions& slice_dimensions :
           proto.slice_dimensions()) {
        slice_starts.push_back(slice_dimensions.start());
        slice_limits.push_back(slice_dimensions.limit());
        slice_strides.push_back(slice_dimensions.stride());
      }
      instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits,
                                slice_strides);
      break;
    }
    case HloOpcode::kConstant: {
      // TODO(b/110214922): Revert this to CHECK(proto.has_literal()).
      if (proto.has_literal()) {
        TF_ASSIGN_OR_RETURN(auto literal,
                            Literal::CreateFromProto(proto.literal()));
        instruction = CreateConstant(std::move(literal));
      } else {
        instruction = absl::make_unique<HloConstantInstruction>(shape);
      }
      break;
    }
    case HloOpcode::kTrace: {
      TF_RET_CHECK(proto.has_literal());
      TF_ASSIGN_OR_RETURN(auto literal,
                          Literal::CreateFromProto(proto.literal()));
      instruction = CreateTrace(literal.GetR1U8AsString(), operands(0));
      break;
    }
    case HloOpcode::kFusion: {
      // In the proto, fused computations are held exclusively within the
      // HloInstructionProto and do not appear as an HloComputationProto within
      // the HloModuleProto.
      TF_RET_CHECK(!proto.fusion_kind().empty());
      TF_ASSIGN_OR_RETURN(FusionKind fusion_kind,
                          StringToFusionKind(proto.fusion_kind()));

      // Find the fused computation and set its fusion instruction.
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Expect 1 called computation for fusion instruction but sees "
          << proto.called_computation_ids_size();
      const int64 fusion_id = proto.called_computation_ids(0);
      auto* fused_computation =
          tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
      TF_RET_CHECK(fused_computation != nullptr)
          << "No fusion computation with id " << fusion_id;
      instruction =
          CreateFusion(shape, fusion_kind, all_operands(), fused_computation);
      break;
    }
    case HloOpcode::kRng:
      instruction = CreateRng(shape, proto.distribution(), all_operands());
      break;
    case HloOpcode::kParameter:
      instruction =
          CreateParameter(proto.parameter_number(), shape, proto.name());
      if (!proto.parameter_replication().replicated_at_leaf_buffers().empty()) {
        instruction->set_parameter_replicated_at_leaf_buffers(
            proto.parameter_replication().replicated_at_leaf_buffers());
      }
      break;
    case HloOpcode::kGetTupleElement:
      instruction =
          CreateGetTupleElement(shape, operands(0), proto.tuple_index());
      break;
    case HloOpcode::kReducePrecision:
      instruction = CreateReducePrecision(
          shape, operands(0), proto.exponent_bits(), proto.mantissa_bits());
      break;
    case HloOpcode::kInfeed: {
      TF_RET_CHECK(shape.IsTuple() &&
                   (ShapeUtil::TupleElementCount(shape) == 2))
          << "Infeed should have a tuple shape with 2 elements, but has: "
          << shape;
      const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0);
      instruction =
          CreateInfeed(data_shape, operands(0), proto.infeed_config());
    } break;
    case HloOpcode::kOutfeed: {
      Shape outfeed_shape(proto.outfeed_shape());
      TF_RETURN_IF_ERROR(
          ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape));
      instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1),
                                  proto.outfeed_config());
      break;
    }
    case HloOpcode::kAllReduce: {
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "AllReduce should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      absl::optional<int64> all_reduce_id;
      if (proto.all_reduce_id() > 0) {
        all_reduce_id = proto.all_reduce_id();
      }
      instruction = CreateAllReduce(
          shape, all_operands(), computations(0),
          /*replica_groups=*/
          std::vector<ReplicaGroup>(proto.replica_groups().begin(),
                                    proto.replica_groups().end()),
          /*barrier=*/proto.all_reduce_barrier(),
          /*all_reduce_id=*/all_reduce_id);
      break;
    }
    case HloOpcode::kAllToAll: {
      instruction = CreateAllToAll(
          shape, all_operands(),
          /*replica_groups=*/
          std::vector<ReplicaGroup>(proto.replica_groups().begin(),
                                    proto.replica_groups().end()));
      break;
    }
    case HloOpcode::kCollectivePermute: {
      std::vector<std::pair<int64, int64>> source_target_pairs(
          proto.source_target_pairs_size());
      for (int i = 0; i < source_target_pairs.size(); i++) {
        source_target_pairs[i].first = proto.source_target_pairs(i).source();
        source_target_pairs[i].second = proto.source_target_pairs(i).target();
      }
      instruction =
          CreateCollectivePermute(shape, operands(0), source_target_pairs);
      break;
    }
    case HloOpcode::kReplicaId: {
      instruction = CreateReplicaId();
      break;
    }
    case HloOpcode::kConvolution: {
      TF_RET_CHECK(proto.has_window());
      TF_RET_CHECK(proto.has_convolution_dimension_numbers());
      PrecisionConfig precision_config = proto.precision_config();
      precision_config.mutable_operand_precision()->Resize(
          proto.operand_ids_size(), PrecisionConfig::DEFAULT);
      instruction = CreateConvolve(
          shape, operands(0), operands(1),
          std::max<int64>(proto.feature_group_count(), 1),
          std::max<int64>(proto.batch_group_count(), 1), proto.window(),
          proto.convolution_dimension_numbers(), precision_config);
      break;
    }
    case HloOpcode::kReduceWindow:
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "ReduceWindow should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      instruction = CreateReduceWindow(shape, operands(0), operands(1),
                                       proto.window(), computations(0));
      break;
    case HloOpcode::kSelectAndScatter:
      TF_RET_CHECK(proto.called_computation_ids_size() == 2)
          << "SelectAndScatter should have 2 called computations but sees "
          << proto.called_computation_ids_size();
      instruction = CreateSelectAndScatter(shape, operands(0), computations(0),
                                           proto.window(), operands(1),
                                           operands(2), computations(1));
      break;
    case HloOpcode::kCustomCall:
      if (proto.constrain_layout()) {
        // A proto RepeatedPtrField cannot be converted to a Span (it is a
        // vector of pointers essentially) so create a vector of shapes to pass
        // in.
        std::vector<Shape> operand_shapes;
        for (const ShapeProto& shape_proto :
             proto.operand_shapes_with_layout()) {
          operand_shapes.emplace_back(shape_proto);
        }
        instruction =
            CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
                             operand_shapes, proto.custom_call_opaque());
      } else {
        instruction =
            CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
                             proto.custom_call_opaque());
      }
      if (proto.has_window()) {
        static_cast<HloCustomCallInstruction*>(instruction.get())
            ->set_window(proto.window());
      }
      if (proto.has_convolution_dimension_numbers()) {
        static_cast<HloCustomCallInstruction*>(instruction.get())
            ->set_convolution_dimension_numbers(
                proto.convolution_dimension_numbers());
      }
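      // A group count of 0 from the proto (its default value) is clamped to
      // 1, mirroring the convolution case above.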
      static_cast<HloCustomCallInstruction*>(instruction.get())
          ->set_feature_group_count(
              std::max(static_cast<int64>(proto.feature_group_count()), 1LL));
      static_cast<HloCustomCallInstruction*>(instruction.get())
          ->set_batch_group_count(
              std::max(static_cast<int64>(proto.batch_group_count()), 1LL));
      break;
    case HloOpcode::kPad:
      TF_RET_CHECK(proto.has_padding_config());
      instruction =
          CreatePad(shape, operands(0), operands(1), proto.padding_config());
      break;
    case HloOpcode::kDynamicSlice: {
      std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size());
      absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
      TF_RET_CHECK(proto.operand_ids_size() >= 1)
          << "DynamicSlice instruction should have at least 1 operand but "
             "sees "
          << proto.operand_ids_size();
      // TODO(b/118437727): Old form, make the check unconditional.
      if (proto.operand_ids_size() != 2 || operands(1)->shape().rank() != 1) {
        auto expected_operands = 1 + operands(0)->shape().rank();
        TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
            << "DynamicSlice instruction should have " << expected_operands
            << " operands, but has " << proto.operand_ids_size();
      }
      const auto& operand_vector = all_operands();
      instruction = CreateDynamicSlice(
          shape, operands(0), absl::MakeSpan(operand_vector).subspan(1),
          slice_sizes);
      break;
    }
    case HloOpcode::kDynamicUpdateSlice: {
      TF_RET_CHECK(proto.operand_ids_size() >= 2)
          << "DynamicUpdateSlice instruction should have at least 2 operands "
             "but sees "
          << proto.operand_ids_size();
      // TODO(b/118437727): Old form, make the check unconditional.
      if (proto.operand_ids_size() != 3 || operands(2)->shape().rank() != 1) {
        auto expected_operands = 2 + operands(0)->shape().rank();
        TF_RET_CHECK(proto.operand_ids_size() == expected_operands)
            << "DynamicUpdateSlice instruction should have "
            << expected_operands << " operands, but has "
            << proto.operand_ids_size();
      }
      const auto& operand_vector = all_operands();
      instruction =
          CreateDynamicUpdateSlice(shape, operands(0), operands(1),
                                   absl::MakeSpan(operand_vector).subspan(2));

      break;
    }
    case HloOpcode::kGather: {
      TF_RET_CHECK(proto.has_gather_dimension_numbers())
          << "Gather instruction should have GatherDimensionNumbers set.";
      std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers =
          absl::make_unique<GatherDimensionNumbers>(
              proto.gather_dimension_numbers());
      std::vector<int64> gather_slice_sizes;
      for (int64 bound : proto.gather_slice_sizes()) {
        gather_slice_sizes.push_back(bound);
      }
      instruction = CreateGather(shape, operands(0), operands(1),
                                 *gather_dimension_numbers, gather_slice_sizes);
      break;
    }
    case HloOpcode::kScatter: {
      TF_RET_CHECK(proto.has_scatter_dimension_numbers())
          << "Scatter instruction should have ScatterDimensionNumbers set.";
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "Scatter instruction should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      auto scatter_dimension_numbers =
          absl::make_unique<ScatterDimensionNumbers>(
              proto.scatter_dimension_numbers());
      instruction = CreateScatter(shape, operands(0), operands(1), operands(2),
                                  computations(0), *scatter_dimension_numbers);
      break;
    }
    case HloOpcode::kIota:
      TF_RET_CHECK(proto.dimensions_size() == 1)
          << "Iota instruction should have 1 dimension but sees "
          << proto.dimensions_size();
      instruction = CreateIota(shape, proto.dimensions(0));
      break;
    case HloOpcode::kDot: {
      TF_RET_CHECK(proto.has_dot_dimension_numbers())
          << "Dot instruction should have dot_dimension_numbers.";
      PrecisionConfig precision_config = proto.precision_config();
      precision_config.mutable_operand_precision()->Resize(
          proto.operand_ids_size(), PrecisionConfig::DEFAULT);
      instruction = absl::make_unique<HloDotInstruction>(
          shape, operands(0), operands(1), proto.dot_dimension_numbers(),
          precision_config);
      break;
    }
    case HloOpcode::kDomain: {
      std::shared_ptr<const HloSharding> entry_hlo_sharding;
      std::shared_ptr<const HloSharding> exit_hlo_sharding;
      if (proto.has_domain_entry_sharding()) {
        TF_ASSIGN_OR_RETURN(
            HloSharding sharding,
            HloSharding::FromProto(proto.domain_entry_sharding()));
        entry_hlo_sharding = std::make_shared<const HloSharding>(sharding);
      }
      if (proto.has_domain_exit_sharding()) {
        TF_ASSIGN_OR_RETURN(
            HloSharding sharding,
            HloSharding::FromProto(proto.domain_exit_sharding()));
        exit_hlo_sharding = std::make_shared<const HloSharding>(sharding);
      }
      instruction = absl::make_unique<HloDomainInstruction>(
          shape, operands(0),
          absl::make_unique<ShardingMetadata>(entry_hlo_sharding),
          absl::make_unique<ShardingMetadata>(exit_hlo_sharding));
      break;
    }
    case HloOpcode::kGetDimensionSize:
      TF_RET_CHECK(proto.dimensions_size() == 1);
      instruction =
          CreateGetDimensionSize(shape, operands(0), proto.dimensions(0));
      break;
    default: {
      instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
      for (const int64 operand_id : proto.operand_ids()) {
        instruction->AppendOperand(instruction_map.at(operand_id));
      }
      if (instruction->opcode() != HloOpcode::kFusion) {
        for (const int64 computation_id : proto.called_computation_ids()) {
          instruction->called_computations_.push_back(
              computation_map.at(computation_id));
        }
      }
      TF_RET_CHECK(!proto.has_precision_config())
          << instruction->opcode() << proto.DebugString();
      TF_RET_CHECK(!proto.has_dot_dimension_numbers()) << instruction->opcode();
      break;
    }
  }

  for (const int64 predecessor_id : proto.control_predecessor_ids()) {
    TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
        << "No instruction with id " << predecessor_id;
    TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
                           ->AddControlDependencyTo(instruction.get()));
  }

  TF_RET_CHECK(!proto.name().empty());
  instruction->SetAndSanitizeName(proto.name());
  instruction->metadata_ = proto.metadata();
  instruction->backend_config_ = proto.backend_config();

  TF_RET_CHECK(proto.id() >= 0)
      << "Instruction with negative id: " << proto.id();
  TF_RET_CHECK(proto.id() <= INT_MAX)
      << "Instruction with id > INT_MAX: " << proto.id();
  instruction->unique_id_ = proto.id();

  if (proto.has_sharding()) {
    TF_ASSIGN_OR_RETURN(const auto& sharding,
                        HloSharding::FromProto(proto.sharding()));
    instruction->set_sharding(sharding);
  }

  return std::move(instruction);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateParameter(
    int64 parameter_number, const Shape& shape, const string& name) {
  return absl::make_unique<HloParameterInstruction>(parameter_number, shape,
                                                    name);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
    const string& tag, HloInstruction* operand) {
  return absl::make_unique<HloTraceInstruction>(tag, operand);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
    Literal literal) {
  return absl::make_unique<HloConstantInstruction>(std::move(literal));
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota(
    const Shape& shape, int64 iota_dimension) {
  return absl::make_unique<HloIotaInstruction>(shape, iota_dimension);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetTupleElement(const Shape& shape,
                                      HloInstruction* operand, int64 index) {
  return absl::make_unique<HloGetTupleElementInstruction>(shape, operand,
                                                          index);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRng(
    const Shape& shape, RandomDistribution distribution,
    absl::Span<HloInstruction* const> parameters) {
  return absl::make_unique<HloRngInstruction>(shape, distribution, parameters);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateNary(
    const Shape& shape, HloOpcode opcode,
    absl::Span<HloInstruction* const> operands) {
  if (opcode == HloOpcode::kCopy) {
    // It is impossible to copy an opaque shape; we don't know how big it is.
    CHECK(!shape.IsOpaque());
  }
  auto instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
  for (auto operand : operands) {
    instruction->AppendOperand(operand);
  }
  return instruction;
}

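// CreateUnary, CreateBinary, and CreateTernary below are thin opcode-checked
// wrappers around CreateNary; for example, CreateBinary(shape,
// HloOpcode::kAdd, lhs, rhs) forwards to CreateNary(shape, HloOpcode::kAdd,
// {lhs, rhs}).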
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateUnary(
    const Shape& shape, HloOpcode opcode, HloInstruction* operand) {
  // Only certain opcodes are supported with CreateUnary: opcodes of unary
  // instructions with no auxiliary fields.
  switch (opcode) {
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kBitcast:
    case HloOpcode::kCeil:
    case HloOpcode::kCopy:
    case HloOpcode::kCos:
    case HloOpcode::kClz:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFloor:
    case HloOpcode::kImag:
    case HloOpcode::kIsFinite:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kNot:
    case HloOpcode::kNegate:
    case HloOpcode::kReal:
    case HloOpcode::kRsqrt:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
    case HloOpcode::kSqrt:
    case HloOpcode::kTanh:
      break;
    default:
      LOG(FATAL) << "Invalid unary instruction opcode "
                 << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {operand});
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBinary(
    const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
    HloInstruction* rhs) {
  // Only certain opcodes are supported with CreateBinary: opcodes of binary
  // instructions with no auxiliary fields.
  switch (opcode) {
    case HloOpcode::kAdd:
    case HloOpcode::kAtan2:
    case HloOpcode::kDivide:
    case HloOpcode::kComplex:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kPower:
    case HloOpcode::kRemainder:
    case HloOpcode::kSubtract:
    case HloOpcode::kAnd:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
      break;
    default:
      LOG(FATAL) << "Invalid binary instruction opcode "
                 << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {lhs, rhs});
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTernary(
    const Shape& shape, HloOpcode opcode, HloInstruction* lhs,
    HloInstruction* rhs, HloInstruction* ehs) {
  // Only certain opcodes are supported with CreateTernary: opcodes of ternary
  // instructions with no auxiliary fields.
  switch (opcode) {
    case HloOpcode::kClamp:
    case HloOpcode::kSelect:
    case HloOpcode::kTupleSelect:
      break;
    default:
      LOG(FATAL) << "Invalid ternary instruction opcode "
                 << HloOpcodeString(opcode);
  }
  return CreateNary(shape, opcode, {lhs, rhs, ehs});
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateVariadic(
    const Shape& shape, HloOpcode opcode,
    absl::Span<HloInstruction* const> operands) {
  CHECK_EQ(HloOpcode::kTuple, opcode);
  return CreateNary(shape, opcode, operands);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateMap(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloComputation* map_computation) {
  return absl::make_unique<HloMapInstruction>(shape, operands, map_computation);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvolve(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    int64 feature_group_count, int64 batch_group_count, const Window& window,
    const ConvolutionDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config) {
  return absl::make_unique<HloConvolutionInstruction>(
      shape, lhs, rhs, feature_group_count, batch_group_count, window,
      dimension_numbers, precision_config);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFft(
    const Shape& shape, HloInstruction* operand, FftType fft_type,
    absl::Span<const int64> fft_length) {
  return absl::make_unique<HloFftInstruction>(shape, operand, fft_type,
                                              fft_length);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    ComparisonDirection direction) {
  return absl::make_unique<HloCompareInstruction>(shape, lhs, rhs, direction);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
                                      HloInstruction* b,
                                      const TriangularSolveOptions& options) {
  return absl::make_unique<HloTriangularSolveInstruction>(shape, a, b, options);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCholesky(
    const Shape& shape, HloInstruction* a, const CholeskyOptions& options) {
  return absl::make_unique<HloCholeskyInstruction>(shape, a, options);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
    const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
    const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config) {
  return absl::make_unique<HloDotInstruction>(
      shape, lhs, rhs, dimension_numbers, precision_config);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateReducePrecision(const Shape& shape,
                                      HloInstruction* operand,
                                      const int exponent_bits,
                                      const int mantissa_bits) {
  return absl::make_unique<HloReducePrecisionInstruction>(
      shape, operand, exponent_bits, mantissa_bits);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllReduce(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    HloComputation* reduce_computation,
    const std::vector<ReplicaGroup>& replica_groups, absl::string_view barrier,
    const absl::optional<int64>& all_reduce_id) {
  return absl::make_unique<HloAllReduceInstruction>(
      shape, operands, reduce_computation, replica_groups, barrier,
      all_reduce_id);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAllToAll(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    const std::vector<ReplicaGroup>& replica_groups) {
  return absl::make_unique<HloAllToAllInstruction>(shape, operands,
                                                   replica_groups);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCollectivePermute(
    const Shape& shape, HloInstruction* operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs) {
  return absl::make_unique<HloCollectivePermuteInstruction>(
      shape, operand, source_target_pairs);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReplicaId() {
  return absl::WrapUnique(
      new HloInstruction(HloOpcode::kReplicaId, ShapeUtil::MakeShape(U32, {})));
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
    const Shape& infeed_shape, HloInstruction* token_operand,
    const string& config) {
  return absl::make_unique<HloInfeedInstruction>(infeed_shape, token_operand,
                                                 config);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateOutfeed(
    const Shape& outfeed_shape, HloInstruction* operand,
    HloInstruction* token_operand, absl::string_view outfeed_config) {
  return absl::make_unique<HloOutfeedInstruction>(
      outfeed_shape, operand, token_operand, outfeed_config);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
    HloInstruction* operand, HloInstruction* token, int64 channel_id,
    bool is_host_transfer) {
  return absl::make_unique<HloSendInstruction>(operand, token, channel_id,
                                               is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
    HloInstruction* operand, bool is_host_transfer) {
  auto send_operand = DynCast<HloSendInstruction>(operand);
  CHECK(send_operand != nullptr)
      << "SendDone must take the context operand from Send";
  return absl::make_unique<HloSendDoneInstruction>(send_operand,
                                                   is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
    const Shape& shape, HloInstruction* token, int64 channel_id,
    bool is_host_transfer) {
  return absl::make_unique<HloRecvInstruction>(shape, token, channel_id,
                                               is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
    HloInstruction* operand, bool is_host_transfer) {
  auto recv_operand = DynCast<HloRecvInstruction>(operand);
  CHECK(recv_operand != nullptr)
      << "RecvDone must take the context operand from Recv";
  return absl::make_unique<HloRecvDoneInstruction>(recv_operand,
                                                   is_host_transfer);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64> dimensions) {
  return absl::make_unique<HloReverseInstruction>(shape, operand, dimensions);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll(
    absl::Span<HloInstruction* const> operands) {
  CHECK(!operands.empty());
  auto instruction = absl::WrapUnique(
      new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
  for (auto operand : operands) {
    instruction->AppendOperand(operand);
  }
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() {
  return absl::WrapUnique(
      new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape()));
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateAddDependency(HloInstruction* data_operand,
                                    HloInstruction* token_operand) {
  auto instruction = absl::WrapUnique(
      new HloInstruction(HloOpcode::kAddDependency, data_operand->shape()));
  instruction->AppendOperand(data_operand);
  instruction->AppendOperand(token_operand);
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile(
    const Shape& shape, HloComputation* condition, HloComputation* body,
    HloInstruction* init) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kWhile, shape));
  instruction->AppendOperand(init);
  // Body comes before condition computation in the vector.
  instruction->called_computations_.push_back(body);
  instruction->called_computations_.push_back(condition);
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
    const Shape& shape, HloInstruction* pred,
    HloInstruction* true_computation_arg, HloComputation* true_computation,
    HloInstruction* false_computation_arg, HloComputation* false_computation) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
  instruction->AppendOperand(pred);
  instruction->AppendOperand(true_computation_arg);
  instruction->AppendOperand(false_computation_arg);
  // In called_computations_, the index of true_computation must be 0 and that
  // of false computation must be 1, as defined by kTrueComputationIndex and
  // kFalseComputationIndex.
  instruction->called_computations_.push_back(true_computation);
  instruction->called_computations_.push_back(false_computation);
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConditional(
    const Shape& shape, HloInstruction* branch_index,
    absl::Span<HloComputation* const> branch_computations,
    absl::Span<HloInstruction* const> branch_computation_args) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kConditional, shape));
  instruction->AppendOperand(branch_index);
  CHECK_EQ(branch_computations.size(), branch_computation_args.size());
  for (int i = 0; i < branch_computations.size(); ++i) {
    instruction->called_computations_.push_back(branch_computations[i]);
    instruction->AppendOperand(branch_computation_args[i]);
  }
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSlice(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64> start_indices,
    absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
  return absl::make_unique<HloSliceInstruction>(shape, operand, start_indices,
                                                limit_indices, strides);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDynamicSlice(
    const Shape& shape, HloInstruction* operand,
    absl::Span<HloInstruction* const> start_indices,
    absl::Span<const int64> slice_sizes) {
  return absl::make_unique<HloDynamicSliceInstruction>(
      shape, operand, start_indices, slice_sizes);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateDynamicUpdateSlice(
    const Shape& shape, HloInstruction* operand, HloInstruction* update,
    absl::Span<HloInstruction* const> start_indices) {
  return absl::make_unique<HloDynamicUpdateSliceInstruction>(
      shape, operand, update, start_indices);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConcatenate(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    int64 dimension) {
  return absl::make_unique<HloConcatenateInstruction>(shape, operands,
                                                      dimension);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConvert(
    const Shape& shape, HloInstruction* operand) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kConvert, shape));
  instruction->AppendOperand(operand);
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBitcastConvert(const Shape& shape,
                                     HloInstruction* operand) {
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kBitcastConvert, shape));
  instruction->AppendOperand(operand);
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
    const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    absl::Span<const int64> dimensions_to_reduce,
    HloComputation* reduce_computation) {
  auto instruction = absl::WrapUnique(new HloReduceInstruction(
      shape, {operand, init_value}, dimensions_to_reduce, reduce_computation));
  return std::move(instruction);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduce(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    absl::Span<HloInstruction* const> init_values,
    absl::Span<const int64> dimensions_to_reduce,
    HloComputation* reduce_computation) {
  std::vector<HloInstruction*> all_args;
  all_args.reserve(operands.size() * 2);
  all_args.insert(all_args.end(), operands.begin(), operands.end());
  all_args.insert(all_args.end(), init_values.begin(), init_values.end());
  return absl::make_unique<HloReduceInstruction>(
      shape, all_args, dimensions_to_reduce, reduce_computation);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
    const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    const Window& window, HloComputation* reduce_computation) {
  return absl::make_unique<HloReduceWindowInstruction>(
      shape, operand, init_value, window, reduce_computation);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormTraining(const Shape& shape,
                                        HloInstruction* operand,
                                        HloInstruction* scale,
                                        HloInstruction* offset, float epsilon,
                                        int64 feature_index) {
  return absl::make_unique<HloBatchNormTrainingInstruction>(
      shape, operand, scale, offset, epsilon, feature_index);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormInference(
    const Shape& shape, HloInstruction* operand, HloInstruction* scale,
    HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
    float epsilon, int64 feature_index) {
  return absl::make_unique<HloBatchNormInferenceInstruction>(
      shape, operand, scale, offset, mean, variance, epsilon, feature_index);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormGrad(const Shape& shape, HloInstruction* operand,
                                    HloInstruction* scale, HloInstruction* mean,
                                    HloInstruction* variance,
                                    HloInstruction* grad_output, float epsilon,
                                    int64 feature_index) {
  return absl::make_unique<HloBatchNormGradInstruction>(
      shape, operand, scale, mean, variance, grad_output, epsilon,
      feature_index);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateSelectAndScatter(
    const Shape& shape, HloInstruction* operand, HloComputation* select,
    const Window& window, HloInstruction* source, HloInstruction* init_value,
    HloComputation* scatter) {
  return absl::make_unique<HloSelectAndScatterInstruction>(
      shape, operand, select, window, source, init_value, scatter);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateBroadcast(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64> broadcast_dimensions) {
  return absl::make_unique<HloBroadcastInstruction>(shape, operand,
                                                    broadcast_dimensions);
}

/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateGetDimensionSize(const Shape& shape,
                                       HloInstruction* operand,
                                       int64 dimension) {
  return absl::make_unique<HloGetDimensionSizeInstruction>(shape, operand,
                                                           dimension);
}

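// Builds a broadcast of `operand` to `output_shape`, inserting an explicit
// reshape first when the operand has degenerate (size-1) dimensions. For
// example, an f32[1,4] operand broadcast to f32[3,4] is reshaped to f32[4]
// and then broadcast along dimension 1. The `adder` callback is invoked to
// materialize the intermediate reshape.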
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBroadcastSequence(
    const Shape& output_shape, HloInstruction* operand,
    const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
        adder) {
  CHECK(ShapeUtil::IsScalar(operand->shape()) ||
        operand->shape().rank() == output_shape.rank());
  Shape broadcast_shape = ShapeUtil::ChangeElementType(
      output_shape, operand->shape().element_type());
  // Do explicit broadcast for scalar.
  if (ShapeUtil::IsScalar(operand->shape())) {
    auto broadcast =
        HloInstruction::CreateBroadcast(broadcast_shape, operand, {});
    broadcast->set_metadata(operand->metadata());
    if (operand->has_sharding()) {
      broadcast->set_sharding(operand->sharding());
    }
    return broadcast;
  }
  // Do explicit broadcast for degenerate broadcast.
  std::vector<int64> broadcast_dimensions;
  std::vector<int64> reshaped_dimensions;
  for (int i = 0; i < operand->shape().rank(); i++) {
    if (operand->shape().dimensions(i) == output_shape.dimensions(i)) {
      broadcast_dimensions.push_back(i);
      reshaped_dimensions.push_back(operand->shape().dimensions(i));
    } else {
      CHECK_EQ(operand->shape().dimensions(i), 1)
          << "An explicit broadcast sequence requires the broadcasted "
             "dimensions to be trivial; operand: "
          << operand->ToString() << "; output_shape: " << output_shape;
    }
  }
  // Eliminate the size one dimensions.
  HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape(
      ShapeUtil::MakeShape(operand->shape().element_type(),
                           reshaped_dimensions),
      operand));
  reshaped_operand->set_metadata(operand->metadata());
  if (operand->has_sharding()) {
    reshaped_operand->set_sharding(operand->sharding());
  }
  // Broadcast 'reshape' up to the larger size.
  auto broadcast = HloInstruction::CreateBroadcast(
      broadcast_shape, reshaped_operand, broadcast_dimensions);
  broadcast->set_metadata(operand->metadata());
  if (operand->has_sharding()) {
    broadcast->set_sharding(operand->sharding());
  }
  return broadcast;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreatePad(
    const Shape& shape, HloInstruction* operand, HloInstruction* padding_value,
    const PaddingConfig& padding_config) {
  return absl::make_unique<HloPadInstruction>(shape, operand, padding_value,
                                              padding_config);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
    const Shape& shape, HloInstruction* operand) {
  CHECK_EQ(ShapeUtil::ElementsIn(shape),
           ShapeUtil::ElementsIn(operand->shape()))
      << "shape: " << ShapeUtil::HumanString(shape)
      << " operand: " << ShapeUtil::HumanString(operand->shape());
  auto instruction =
      absl::WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
  instruction->AppendOperand(operand);
  return instruction;
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTranspose(
    const Shape& shape, HloInstruction* operand,
    absl::Span<const int64> dimensions) {
  return absl::make_unique<HloTransposeInstruction>(shape, operand, dimensions);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
    const Shape& shape, int64 dimension,
    absl::Span<HloInstruction* const> operands, HloComputation* compare,
    bool is_stable) {
  return absl::make_unique<HloSortInstruction>(shape, dimension, operands,
                                               compare, is_stable);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
    const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
  return absl::make_unique<HloFusionInstruction>(shape, fusion_kind,
                                                 fused_root);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
    const Shape& shape, FusionKind fusion_kind,
    absl::Span<HloInstruction* const> operands,
    HloComputation* fusion_computation) {
  return absl::make_unique<HloFusionInstruction>(shape, fusion_kind, operands,
                                                 fusion_computation);
}

void HloInstruction::set_single_sharding(const HloSharding& sharding) {
  CHECK(!sharding.IsTuple()) << sharding;
  if (shape().IsTuple()) {
    set_sharding(HloSharding::Tuple(sharding.GetAsShapeTree(shape())));
  } else {
    set_sharding(sharding);
  }
}

void HloInstruction::SetupDerivedInstruction(
    HloInstruction* derived_instruction) const {
  if (sharding_ != nullptr && ShapeUtil::CompatibleIgnoringElementType(
                                  shape_, derived_instruction->shape())) {
    // Only copy sharding if the shapes of the two instructions are compatible,
    // because copying it between differently shaped instructions can produce
    // invalid shardings.
    derived_instruction->set_sharding(*sharding_);
  } else {
    derived_instruction->clear_sharding();
  }
  derived_instruction->set_metadata(metadata_);
}
1236
1237 bool HloInstruction::HasSideEffectNoRecurse() const {
1238 switch (opcode_) {
1239 case HloOpcode::kSend:
1240 case HloOpcode::kSendDone:
1241 case HloOpcode::kRecv:
1242 case HloOpcode::kRecvDone:
1243 case HloOpcode::kRng:
1244 case HloOpcode::kInfeed:
1245 case HloOpcode::kOutfeed:
1246 case HloOpcode::kTrace:
1247 return true;
1248 case HloOpcode::kAllReduce:
1249 return all_reduce_id().has_value();
1250 default:
1251 return false;
1252 }
1253 }
1254
1255 bool HloInstruction::HasSideEffect() const {
1256 if (HasSideEffectNoRecurse()) {
1257 return true;
1258 }
1259 // Check if any of the called computations has a side effect.
1260 for (const auto& computation : called_computations()) {
1261 if (computation->HasSideEffect()) {
1262 return true;
1263 }
1264 }
1265 return false;
1266 }
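// For example, a kCall has no side effect of its own (HasSideEffectNoRecurse()
// is false), but if its called computation contains, say, a kOutfeed, then
// HasSideEffect() returns true because the loop above recurses into called
// computations.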
1267
1268 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCall(
1269 const Shape& shape, absl::Span<HloInstruction* const> operands,
1270 HloComputation* computation) {
1271 std::unique_ptr<HloInstruction> instruction =
1272 absl::WrapUnique(new HloInstruction(HloOpcode::kCall, shape));
1273 for (auto operand : operands) {
1274 instruction->AppendOperand(operand);
1275 }
1276 instruction->called_computations_.push_back(computation);
1277 return instruction;
1278 }
1279
1280 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1281 const Shape& shape, absl::Span<HloInstruction* const> operands,
1282 absl::string_view custom_call_target, absl::string_view opaque) {
1283 return absl::make_unique<HloCustomCallInstruction>(
1284 shape, operands, custom_call_target, opaque);
1285 }
1286
1287 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
1288 const Shape& shape, absl::Span<HloInstruction* const> operands,
1289 absl::string_view custom_call_target,
1290 absl::Span<const Shape> operand_shapes_with_layout,
1291 absl::string_view opaque) {
1292 return absl::make_unique<HloCustomCallInstruction>(
1293 shape, operands, custom_call_target, opaque, operand_shapes_with_layout);
1294 }
1295
1296 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
1297 absl::Span<HloInstruction* const> elements) {
1298 std::vector<Shape> element_shapes;
1299 for (auto element : elements) {
1300 element_shapes.push_back(element->shape());
1301 }
1302 Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);
1303 return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
1304 }
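// For example, CreateTuple({a, b}) with a of shape f32[2] and b of shape s32[]
// produces a kTuple instruction whose shape is the tuple (f32[2], s32[]).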
1305
1306 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
1307 const Shape& shape, HloInstruction* operand, HloInstruction* start_indices,
1308 const GatherDimensionNumbers& gather_dim_numbers,
1309 absl::Span<const int64> slice_sizes) {
1310 return absl::make_unique<HloGatherInstruction>(
1311 shape, operand, start_indices, gather_dim_numbers, slice_sizes);
1312 }
1313
1314 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
1315 const Shape& shape, HloInstruction* operand,
1316 HloInstruction* scatter_indices, HloInstruction* updates,
1317 HloComputation* update_computation,
1318 const ScatterDimensionNumbers& scatter_dim_numbers) {
1319 return absl::make_unique<HloScatterInstruction>(
1320 shape, operand, scatter_indices, updates, update_computation,
1321 scatter_dim_numbers);
1322 }
1323
1324 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
1325 const Shape& shape, HloInstruction* operand,
1326 std::unique_ptr<DomainMetadata> operand_side_metadata,
1327 std::unique_ptr<DomainMetadata> user_side_metadata) {
1328 return absl::make_unique<HloDomainInstruction>(
1329 shape, operand, std::move(operand_side_metadata),
1330 std::move(user_side_metadata));
1331 }
1332
1333 std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
1334 const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1335 HloCloneContext* context) const {
1336 VLOG(3) << "CloneWithNewOperands:\n " << ToString();
1337 VLOG(3) << " new operands:";
1338 for (const HloInstruction* new_operand : new_operands) {
1339 VLOG(3) << " %" << new_operand->name();
1340 }
1341
1342 std::unique_ptr<HloInstruction> clone;
1343 // Explicitly call the factory for the instruction type. This is more robust
1344 // in the face of code changes than copying fields explicitly. This also
1345 // properly sets the user fields of the operands.
1346 switch (opcode_) {
1347 // Ops migrated to subclasses.
1348 // TODO(b/80131774): Remove this switch when migration is complete.
1349 case HloOpcode::kBatchNormTraining:
1350 case HloOpcode::kBatchNormInference:
1351 case HloOpcode::kBatchNormGrad:
1352 case HloOpcode::kFft:
1353 case HloOpcode::kCompare:
1354 case HloOpcode::kSend:
1355 case HloOpcode::kSendDone:
1356 case HloOpcode::kRecv:
1357 case HloOpcode::kRecvDone:
1358 case HloOpcode::kReverse:
1359 case HloOpcode::kConcatenate:
1360 case HloOpcode::kReduce:
1361 case HloOpcode::kTranspose:
1362 case HloOpcode::kBroadcast:
1363 case HloOpcode::kMap:
1364 case HloOpcode::kSlice:
1365 case HloOpcode::kConstant:
1366 case HloOpcode::kTrace:
1367 case HloOpcode::kFusion:
1368 case HloOpcode::kRng:
1369 case HloOpcode::kParameter:
1370 case HloOpcode::kGetTupleElement:
1371 case HloOpcode::kReducePrecision:
1372 case HloOpcode::kAllReduce:
1373 case HloOpcode::kAllToAll:
1374 case HloOpcode::kCollectivePermute:
1375 case HloOpcode::kInfeed:
1376 case HloOpcode::kOutfeed:
1377 case HloOpcode::kConvolution:
1378 case HloOpcode::kCustomCall:
1379 case HloOpcode::kReduceWindow:
1380 case HloOpcode::kSelectAndScatter:
1381 case HloOpcode::kPad:
1382 case HloOpcode::kDynamicSlice:
1383 case HloOpcode::kSort:
1384 case HloOpcode::kGather:
1385 case HloOpcode::kScatter:
1386 case HloOpcode::kIota:
1387 case HloOpcode::kDot:
1388 case HloOpcode::kDomain:
1389 case HloOpcode::kGetDimensionSize:
1390 case HloOpcode::kTriangularSolve:
1391 case HloOpcode::kCholesky:
1392 clone = CloneWithNewOperandsImpl(shape, new_operands, context);
1393 break;
1394 // Unary ops.
1395 case HloOpcode::kAbs:
1396 case HloOpcode::kRoundNearestAfz:
1397 case HloOpcode::kBitcast:
1398 case HloOpcode::kCeil:
1399 case HloOpcode::kClz:
1400 case HloOpcode::kCopy:
1401 case HloOpcode::kCos:
1402 case HloOpcode::kExp:
1403 case HloOpcode::kExpm1:
1404 case HloOpcode::kImag:
1405 case HloOpcode::kIsFinite:
1406 case HloOpcode::kFloor:
1407 case HloOpcode::kLog:
1408 case HloOpcode::kLog1p:
1409 case HloOpcode::kNot:
1410 case HloOpcode::kNegate:
1411 case HloOpcode::kReal:
1412 case HloOpcode::kRsqrt:
1413 case HloOpcode::kSign:
1414 case HloOpcode::kSin:
1415 case HloOpcode::kSqrt:
1416 case HloOpcode::kTanh:
1417 CHECK_EQ(new_operands.size(), 1);
1418 clone = CreateUnary(shape, opcode_, new_operands[0]);
1419 break;
1420 // Binary ops.
1421 case HloOpcode::kAdd:
1422 case HloOpcode::kAtan2:
1423 case HloOpcode::kComplex:
1424 case HloOpcode::kDivide:
1425 case HloOpcode::kMultiply:
1426 case HloOpcode::kSubtract:
1427 case HloOpcode::kMaximum:
1428 case HloOpcode::kMinimum:
1429 case HloOpcode::kPower:
1430 case HloOpcode::kRemainder:
1431 case HloOpcode::kAnd:
1432 case HloOpcode::kOr:
1433 case HloOpcode::kXor:
1434 case HloOpcode::kShiftLeft:
1435 case HloOpcode::kShiftRightArithmetic:
1436 case HloOpcode::kShiftRightLogical:
1437 CHECK_EQ(new_operands.size(), 2);
1438 clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]);
1439 break;
1440 // Ternary ops.
1441 case HloOpcode::kClamp:
1442 case HloOpcode::kSelect:
1443 case HloOpcode::kTupleSelect:
1444 CHECK_EQ(new_operands.size(), 3);
1445 clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1],
1446 new_operands[2]);
1447 break;
1448 // Other supported ops.
1449 case HloOpcode::kCall:
1450 clone = CreateCall(shape, new_operands, to_apply());
1451 break;
1452 case HloOpcode::kConvert:
1453 CHECK_EQ(new_operands.size(), 1);
1454 clone = CreateConvert(shape, new_operands[0]);
1455 break;
1456 case HloOpcode::kBitcastConvert:
1457 CHECK_EQ(new_operands.size(), 1);
1458 clone = CreateBitcastConvert(shape, new_operands[0]);
1459 break;
1460 case HloOpcode::kReshape:
1461 CHECK_EQ(new_operands.size(), 1);
1462 clone = CreateReshape(shape, new_operands[0]);
1463 break;
1464 case HloOpcode::kDynamicUpdateSlice:
1465 clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1],
1466 new_operands.subspan(2));
1467 break;
1468 case HloOpcode::kTuple:
1469 clone = CreateTuple(new_operands);
1470 *clone->mutable_shape() = shape;
1471 break;
1472 case HloOpcode::kWhile:
1473 CHECK_EQ(new_operands.size(), 1);
1474 clone =
1475 CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
1476 break;
1477 case HloOpcode::kConditional:
1478 CHECK_EQ(new_operands.size(), branch_count() + 1);
1479 clone = CreateConditional(shape, new_operands[0],
1480 absl::MakeSpan(branch_computations()),
1481 new_operands.subspan(1));
1482 break;
1483 case HloOpcode::kAfterAll:
1484 if (new_operands.empty()) {
1485 clone = CreateToken();
1486 } else {
1487 clone = CreateAfterAll(new_operands);
1488 }
1489 break;
1490 case HloOpcode::kAddDependency:
1491 CHECK_EQ(new_operands.size(), 2);
1492 clone = CreateAddDependency(new_operands[0], new_operands[1]);
1493 break;
1494 case HloOpcode::kReplicaId:
1495 CHECK_EQ(new_operands.size(), 0);
1496 clone = CreateReplicaId();
1497 break;
1498 }
1499 // SetupDerivedInstruction will setup the precision_config_ field.
1500 SetupDerivedInstruction(clone.get());
1501 clone->set_parent(parent_);
1502 clone->set_raw_backend_config_string(backend_config_);
1503 if (context != nullptr) {
1504 context->MapInstruction(this, clone.get());
1505 clone->ReplaceCalledComputations([&](HloComputation* callee) {
1506 return callee->parent() != context->module()
1507 ? context->module()->DeepCloneComputation(callee, context)
1508 : callee;
1509 });
1510 }
1511 return clone;
1512 }
1513
1514 HloInstruction::~HloInstruction() {
1515 // Detach from operands. An instruction may be repeated as an operand. To
1516 // avoid calling RemoveUser twice on the same operand, check before removing.
1517 for (int64 operand_num = 0; operand_num < operand_count(); ++operand_num) {
1518 HloInstruction* operand = operands_[operand_num];
1519 if (operand == nullptr) {
1520 continue;
1521 }
1522 if (operand->user_set_.find(this) != operand->user_set_.end()) {
1523 operand->RemoveUser(this);
1524 }
1525 operands_[operand_num] = nullptr;
1526 }
1527
1528 // Update users. Set the corresponding operand slot of each user to `nullptr`.
1529 for (auto& user : this->users()) {
1530 for (int i = 0; i < user->operand_count(); ++i) {
1531 if (user->operands_[i] == this) {
1532 user->operands_[i] = nullptr;
1533 }
1534 }
1535 }
1536 }
1537
1538 std::unique_ptr<HloInstruction> HloInstruction::Clone(
1539 const string& suffix, HloCloneContext* context) const {
1540 std::unique_ptr<HloInstruction> clone =
1541 CloneWithNewOperands(shape_, operands_, context);
1542 if (suffix.empty()) {
1543 clone->name_ = name();
1544 } else {
1545 // If an instruction is cloned multiple times, avoid names like
1546 // foo.suffix.suffix.suffix. Instead of repeating the suffix, add a numeric
1547 // suffix. Specifically, the clone of foo.suffix is named foo.suffix2, the
1548 // clone of foo.suffix2 is named foo.suffix3, and so on.
1549 const string dot_suffix = "." + suffix;
1550 size_t index = name().rfind(dot_suffix);
1551 if (index == string::npos) {
1552 // Existing name does not include ".suffix".
1553 clone->name_ = name() + dot_suffix;
1554 } else {
1555 // Existing name includes ".suffix". Determine if substring after
1556 // ".suffix" is numeric and should be replaced with an incremented number.
1557 string after_suffix = name().substr(index + dot_suffix.size());
1558 if (after_suffix.empty()) {
1559 // Existing name ends in ".suffix". New name should end in ".suffix2".
1560 clone->name_ = name() + "2";
1561 } else {
1562 // If the name ends with .suffix[0-9]+, then replace it with a suffix whose
1563 // numeric value is incremented.
1564 int64 numeric_suffix;
1565 if (absl::SimpleAtoi(after_suffix, &numeric_suffix)) {
1566 clone->name_ =
1567 StrCat(name().substr(0, index), dot_suffix, numeric_suffix + 1);
1568 } else {
1569 // Substring after ".suffix" is non-numeric.
1570 clone->name_ = name() + dot_suffix;
1571 }
1572 }
1573 }
1574 }
1575 return clone;
1576 }
1577
1578 std::pair<const HloInstruction*, ShapeIndex>
1579 HloInstruction::LatestNonGteAncestorAndIndex() const {
1580 const HloInstruction* hlo = this;
1581 ShapeIndex index;
1582 while (hlo->opcode() == HloOpcode::kGetTupleElement) {
1583 index.push_back(hlo->tuple_index());
1584 hlo = hlo->operand(0);
1585 }
1586
1587 // We built up index in the reverse order from what we want.
1588 std::reverse(index.begin(), index.end());
1589
1590 return {hlo, index};
1591 }
1592
1593 const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
1594 const HloInstruction* hlo = this;
1595 while (hlo->opcode() == HloOpcode::kGetTupleElement) {
1596 hlo = hlo->operand(0);
1597 }
1598 return hlo;
1599 }
1600
1601 const HloInstruction* HloInstruction::operand(int64 i) const {
1602 return operands_[i];
1603 }
1604
1605 HloInstruction* HloInstruction::mutable_operand(int64 i) {
1606 CHECK(operands_[i] != nullptr);
1607 return operands_[i];
1608 }
1609
1610 int64 HloInstruction::operand_index(const HloInstruction* target) const {
1611 for (int64 i = 0; i < operand_count(); ++i) {
1612 if (target == operand(i)) {
1613 return i;
1614 }
1615 }
1616 LOG(FATAL) << "target was not an operand: " << target->ToString();
1617 }
1618
1619 HloInstruction::InstructionVector HloInstruction::unique_operands() const {
1620 InstructionVector unique;
1621 absl::flat_hash_set<const HloInstruction*> seen;
1622 for (HloInstruction* operand : operands()) {
1623 if (seen.insert(operand).second) {
1624 unique.push_back(operand);
1625 }
1626 }
1627 return unique;
1628 }
1629
1630 Status HloInstruction::AddControlDependencyTo(HloInstruction* instruction) {
1631 TF_RET_CHECK(instruction->parent() == parent());
1632 if (!absl::c_linear_search(control_successors_, instruction)) {
1633 control_successors_.push_back(instruction);
1634 TF_RET_CHECK(
1635 !absl::c_linear_search(instruction->control_predecessors_, this));
1636 instruction->control_predecessors_.push_back(this);
1637 }
1638 return Status::OK();
1639 }
1640
1641 Status HloInstruction::RemoveControlDependencyTo(HloInstruction* instruction) {
1642 TF_RET_CHECK(instruction->parent() == parent());
1643 TF_RETURN_IF_ERROR(EraseElementFromVector(&control_successors_, instruction));
1644 TF_RETURN_IF_ERROR(
1645 EraseElementFromVector(&instruction->control_predecessors_, this));
1646 return Status::OK();
1647 }
1648
1649 Status HloInstruction::DropAllControlDeps() {
1650 for (auto* ctrl_succ : control_successors_) {
1651 TF_RETURN_IF_ERROR(
1652 EraseElementFromVector(&ctrl_succ->control_predecessors_, this));
1653 }
1654 for (auto* ctrl_pred : control_predecessors_) {
1655 TF_RETURN_IF_ERROR(
1656 EraseElementFromVector(&ctrl_pred->control_successors_, this));
1657 }
1658 control_successors_.clear();
1659 control_predecessors_.clear();
1660 return Status::OK();
1661 }
1662
1663 Status HloInstruction::CopyAllControlDepsFrom(const HloInstruction* inst) {
1664 for (auto* ctrl_pred : inst->control_predecessors()) {
1665 TF_RETURN_IF_ERROR(ctrl_pred->AddControlDependencyTo(this));
1666 }
1667
1668 for (auto* ctrl_succ : inst->control_successors()) {
1669 TF_RETURN_IF_ERROR(this->AddControlDependencyTo(ctrl_succ));
1670 }
1671
1672 return Status::OK();
1673 }
1674
1675 void HloInstruction::AppendOperand(HloInstruction* operand) {
1676 operands_.push_back(operand);
1677 operand->AddUser(this);
1678 }
1679
1680 void HloInstruction::RemoveOperandsAtAscendingIndices(
1681 absl::Span<const int> ascending_indices) {
1682 if (ascending_indices.empty()) {
1683 return;
1684 }
1685 int next_index = 0;
1686 int removed_count = 0;
1687 for (int to_remove : ascending_indices) {
1688 while (next_index < to_remove) {
1689 operands_[next_index - removed_count] = operands_[next_index];
1690 ++next_index;
1691 }
1692 CHECK_LT(to_remove, operands_.size());
1693 ++removed_count;
1694 ++next_index;
1695 }
1696 while (next_index < operands_.size()) {
1697 operands_[next_index - removed_count] = operands_[next_index];
1698 ++next_index;
1699 }
1700 CHECK_EQ(removed_count, ascending_indices.size());
1701 operands_.resize(operands_.size() - removed_count);
1702 }
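// For example, with operands {a, b, c, d, e} and ascending_indices {1, 3}, the
// single left-to-right pass above compacts the operand vector to {a, c, e}.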
1703
1704 void HloInstruction::AddUser(HloInstruction* user) {
1705 if (!ContainsKey(user_set_, user)) {
1706 user_set_.insert(user);
1707 users_.push_back(user);
1708 }
1709 }
1710
1711 bool HloInstruction::HasConstantOperand() const {
1712 for (const HloInstruction* operand : operands_) {
1713 if (operand->IsConstant()) {
1714 return true;
1715 }
1716 }
1717 return false;
1718 }
1719
1720 bool HloInstruction::IdenticalSlowPath(
1721 const HloInstruction& other,
1722 const std::function<bool(const HloComputation*, const HloComputation*)>&
1723 eq_computations) const {
1724 // Perform opcode specific checks.
1725 switch (opcode()) {
1726 // The result of these instructions depends only upon their opcode and
1727 // operands.
1728 case HloOpcode::kAbs:
1729 case HloOpcode::kAtan2:
1730 case HloOpcode::kAdd:
1731 case HloOpcode::kBitcast:
1732 case HloOpcode::kBitcastConvert:
1733 case HloOpcode::kCeil:
1734 case HloOpcode::kClamp:
1735 case HloOpcode::kClz:
1736 case HloOpcode::kComplex:
1737 case HloOpcode::kConvert:
1738 case HloOpcode::kCopy:
1739 case HloOpcode::kCos:
1740 case HloOpcode::kDivide:
1741 case HloOpcode::kDynamicUpdateSlice:
1742 case HloOpcode::kExp:
1743 case HloOpcode::kExpm1:
1744 case HloOpcode::kFloor:
1745 case HloOpcode::kImag:
1746 case HloOpcode::kIsFinite:
1747 case HloOpcode::kLog:
1748 case HloOpcode::kLog1p:
1749 case HloOpcode::kAnd:
1750 case HloOpcode::kNot:
1751 case HloOpcode::kOr:
1752 case HloOpcode::kXor:
1753 case HloOpcode::kMaximum:
1754 case HloOpcode::kMinimum:
1755 case HloOpcode::kMultiply:
1756 case HloOpcode::kNegate:
1757 case HloOpcode::kPower:
1758 case HloOpcode::kReal:
1759 case HloOpcode::kRemainder:
1760 case HloOpcode::kReshape:
1761 case HloOpcode::kReplicaId:
1762 case HloOpcode::kRoundNearestAfz:
1763 case HloOpcode::kRsqrt:
1764 case HloOpcode::kSelect:
1765 case HloOpcode::kShiftLeft:
1766 case HloOpcode::kShiftRightArithmetic:
1767 case HloOpcode::kShiftRightLogical:
1768 case HloOpcode::kSign:
1769 case HloOpcode::kSin:
1770 case HloOpcode::kSqrt:
1771 case HloOpcode::kSubtract:
1772 case HloOpcode::kTanh:
1773 case HloOpcode::kTuple:
1774 case HloOpcode::kTupleSelect:
1775 return true;
1776
1777 // These opcodes have complex or special behavior, so just return false.
1778 case HloOpcode::kAfterAll:
1779 case HloOpcode::kAddDependency:
1780 return false;
1781
1782 // Remaining instructions with special values.
1783 case HloOpcode::kCall:
1784 return eq_computations(to_apply(), other.to_apply());
1785 case HloOpcode::kConditional:
1786 for (int j = 0; j < branch_count(); ++j) {
1787 if (!eq_computations(branch_computation(j),
1788 other.branch_computation(j))) {
1789 return false;
1790 }
1791 }
1792 return true;
1793 case HloOpcode::kWhile:
1794 return (eq_computations(while_body(), other.while_body()) &&
1795 eq_computations(while_condition(), other.while_condition()));
1796
1797 // Ops migrated to subclasses should never come to this line.
1798 // TODO(b/80131774): Remove this switch when migration is complete.
1799 case HloOpcode::kBatchNormTraining:
1800 case HloOpcode::kBatchNormInference:
1801 case HloOpcode::kBatchNormGrad:
1802 case HloOpcode::kFft:
1803 case HloOpcode::kCompare:
1804 case HloOpcode::kSend:
1805 case HloOpcode::kSendDone:
1806 case HloOpcode::kRecv:
1807 case HloOpcode::kRecvDone:
1808 case HloOpcode::kReverse:
1809 case HloOpcode::kConcatenate:
1810 case HloOpcode::kReduce:
1811 case HloOpcode::kSort:
1812 case HloOpcode::kTranspose:
1813 case HloOpcode::kBroadcast:
1814 case HloOpcode::kMap:
1815 case HloOpcode::kSlice:
1816 case HloOpcode::kConstant:
1817 case HloOpcode::kIota:
1818 case HloOpcode::kTrace:
1819 case HloOpcode::kFusion:
1820 case HloOpcode::kRng:
1821 case HloOpcode::kParameter:
1822 case HloOpcode::kGetTupleElement:
1823 case HloOpcode::kReducePrecision:
1824 case HloOpcode::kInfeed:
1825 case HloOpcode::kOutfeed:
1826 case HloOpcode::kAllReduce:
1827 case HloOpcode::kAllToAll:
1828 case HloOpcode::kCollectivePermute:
1829 case HloOpcode::kConvolution:
1830 case HloOpcode::kCustomCall:
1831 case HloOpcode::kReduceWindow:
1832 case HloOpcode::kSelectAndScatter:
1833 case HloOpcode::kPad:
1834 case HloOpcode::kDynamicSlice:
1835 case HloOpcode::kGather:
1836 case HloOpcode::kScatter:
1837 case HloOpcode::kDot:
1838 case HloOpcode::kDomain:
1839 case HloOpcode::kGetDimensionSize:
1840 case HloOpcode::kTriangularSolve:
1841 case HloOpcode::kCholesky:
1842 LOG(FATAL) << "Base class impl called for opcode with subclass: "
1843 << opcode();
1844 }
1845 return false;
1846 }
1847
1848 static uint64 HashOperand(const HloInstruction* hlo) {
1849 return ShapeUtil::Hash(hlo->shape());
1850 }
1851
1852 uint64 HloInstruction::Hash(
1853 const std::function<uint64(const HloInstruction*)>& hash_operand) const {
1854 using tensorflow::Hash64Combine;
1855
1856 uint64 hash_value = Hash64Combine(0, static_cast<uint64>(opcode()));
1857 hash_value = Hash64Combine(hash_value, ShapeUtil::Hash(shape()));
1858
1859 if (!IsCrossModuleAllReduce()) {
1860 if (!operands().empty()) {
1861 for (size_t i = 0; i < operands().size(); ++i) {
1862 hash_value = Hash64Combine(hash_value, hash_operand(operand(i)));
1863 }
1864 }
1865 }
1866
1867 hash_value = Hash64Combine(hash_value, InnerHash());
1868 return hash_value;
1869 }
1870
1871 uint64 HloInstruction::Hash() const {
1872 // Use HashOperand as an argument to prevent non-termination.
1873 return Hash(HashOperand);
1874 }
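// Note that HashOperand hashes only the operand's shape, so an instruction's
// hash depends on its operands just one level deep; Hash() therefore stays
// proportional to the operand count instead of walking the whole graph.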
1875
1876 uint64 HloInstruction::InnerHash() const { return 13; }
1877
1878 void HloInstruction::RemoveUser(HloInstruction* user) {
1879 auto set_it = user_set_.find(user);
1880 CHECK(set_it != user_set_.end());
1881 user_set_.erase(set_it);
1882 // This is linear in the number of users, but a vector provides a stable
1883 // iteration order and much faster traversal.
1884 auto vec_it = absl::c_find(users_, user);
1885 CHECK(vec_it != users_.end());
1886 users_.erase(vec_it);
1887 }
1888
1889 Status HloInstruction::ReplaceUseWith(HloInstruction* user,
1890 HloInstruction* new_producer) {
1891 TF_RET_CHECK(
1892 ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
1893 << "this shape: " << ShapeUtil::HumanString(shape())
1894 << ", replacement shape: "
1895 << ShapeUtil::HumanString(new_producer->shape());
1896 return ReplaceUseWithDifferentShape(user, new_producer);
1897 }
1898
1899 Status HloInstruction::ReplaceUseWithDifferentShape(
1900 HloInstruction* user, HloInstruction* new_producer) {
1901 VLOG(3) << "Replacing uses of " << name() << " in " << user->name()
1902 << " with " << new_producer->name();
1903
1904 RemoveUser(user);
1905
1906 TF_RET_CHECK(absl::c_count(user->operands_, this) >= 0);
1907 std::replace(user->operands_.begin(), user->operands_.end(), this,
1908 new_producer);
1909 new_producer->AddUser(user);
1910 if (user->opcode() == HloOpcode::kFusion) {
1911 TF_RETURN_IF_ERROR(
1912 Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
1913 }
1914 return Status::OK();
1915 }
1916
1917 Status HloInstruction::ReplaceOperandWith(int64 operand_num,
1918 HloInstruction* new_operand) {
1919 auto old_operand = operand(operand_num);
1920 TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
1921 new_operand->shape()))
1922 << old_operand->shape() << " is not compatible with "
1923 << new_operand->shape();
1924 return ReplaceOperandWithDifferentShape(operand_num, new_operand);
1925 }
1926
1927 Status HloInstruction::ReplaceOperandWithDifferentShape(
1928 int64 operand_num, HloInstruction* new_operand) {
1929 TF_RET_CHECK(operand_num >= 0);
1930 TF_RET_CHECK(operand_num < operand_count());
1931 HloInstruction* old_operand = mutable_operand(operand_num);
1932 if (old_operand == new_operand) {
1933 return Status::OK();
1934 }
1935
1936 operands_[operand_num] = new_operand;
1937
1938 VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with "
1939 << new_operand->name() << ", was " << old_operand->name();
1940
1941 if (!absl::c_linear_search(operands_, old_operand)) {
1942 old_operand->RemoveUser(this);
1943 }
1944 new_operand->AddUser(this);
1945 return Status::OK();
1946 }
1947
1948 Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
1949 TF_RET_CHECK(
1950 ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
1951 << shape() << " is not compatible with " << new_producer->shape();
1952 return ReplaceAllUsesWithDifferentShape(new_producer);
1953 }
1954
1955 Status HloInstruction::ReplaceAllUsesWithDifferentShape(
1956 HloInstruction* new_producer) {
1957 bool new_producer_is_user = false;
1958 for (HloInstruction* user : users()) {
1959 if (user == new_producer) {
1960 // It's possible that new_producer is a user of this instruction as might
1961 // be the case when replacing an instruction with a kCopy of itself. In
1962 // this case, don't do the replacement to avoid creating a cycle in the
1963 // graph. new_producer remains the only user of this instruction.
1964 new_producer_is_user = true;
1965 } else {
1966 std::replace(user->operands_.begin(), user->operands_.end(), this,
1967 new_producer);
1968 new_producer->AddUser(user);
1969 if (user->opcode() == HloOpcode::kFusion) {
1970 TF_RETURN_IF_ERROR(
1971 Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
1972 }
1973 }
1974 }
1975 users_.clear();
1976 user_set_.clear();
1977 if (new_producer_is_user) {
1978 AddUser(new_producer);
1979 }
1980 if (parent_ && parent_->root_instruction() == this) {
1981 parent_->set_root_instruction(new_producer,
1982 /*accept_different_shape=*/true);
1983 }
1984
1985 return Status::OK();
1986 }
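// For example, replacing all uses of %x with %c = copy(%x) hits the special
// case above: %c is itself a user of %x, so it is skipped during replacement.
// This avoids the cycle %x -> %c -> %x, and %c remains the only user of %x.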
1987
1988 HloComputation* HloInstruction::to_apply() const {
1989 switch (opcode_) {
1990 case HloOpcode::kCall:
1991 case HloOpcode::kMap:
1992 case HloOpcode::kReduceWindow:
1993 case HloOpcode::kReduce:
1994 case HloOpcode::kAllReduce:
1995 case HloOpcode::kScatter:
1996 case HloOpcode::kSort:
1997 CHECK_EQ(called_computations_.size(), 1);
1998 return called_computations_[0];
1999 default:
2000 LOG(FATAL) << "Invalid opcode for to_apply(): "
2001 << HloOpcodeString(opcode());
2002 }
2003 }
2004
2005 void HloInstruction::set_to_apply(HloComputation* computation) {
2006 // Don't allow changing the computation for fused instructions so we don't
2007 // have to recompute called_instructions for the entire fusion instruction.
2008 CHECK(!IsFused());
2009 switch (opcode_) {
2010 case HloOpcode::kCall:
2011 case HloOpcode::kMap:
2012 case HloOpcode::kReduceWindow:
2013 case HloOpcode::kReduce:
2014 case HloOpcode::kAllReduce:
2015 case HloOpcode::kScatter:
2016 case HloOpcode::kSort:
2017 CHECK_EQ(called_computations_.size(), 1);
2018 called_computations_[0] = computation;
2019 break;
2020 default:
2021 LOG(FATAL) << "Invalid opcode for to_apply(): "
2022 << HloOpcodeString(opcode());
2023 }
2024 }
2025
2026 HloComputation* HloInstruction::while_condition() const {
2027 CHECK_EQ(HloOpcode::kWhile, opcode_);
2028 return called_computations_[kConditionComputationIndex];
2029 }
2030
2031 HloComputation* HloInstruction::while_body() const {
2032 CHECK_EQ(HloOpcode::kWhile, opcode_);
2033 return called_computations_[kBodyComputationIndex];
2034 }
2035
2036 void HloInstruction::set_while_condition(HloComputation* computation) {
2037 // Don't allow changing the computation for fused instructions so we don't
2038 // have to recompute called_instructions for the entire fusion instruction.
2039 CHECK(!IsFused());
2040 CHECK_EQ(HloOpcode::kWhile, opcode_);
2041 called_computations_[kConditionComputationIndex] = computation;
2042 }
2043
2044 void HloInstruction::set_while_body(HloComputation* computation) {
2045 // Don't allow changing the computation for fused instructions so we don't
2046 // have to recompute called_instructions for the entire fusion instruction.
2047 CHECK(!IsFused());
2048 CHECK_EQ(HloOpcode::kWhile, opcode_);
2049 called_computations_[kBodyComputationIndex] = computation;
2050 }
2051
2052 HloInstruction* HloInstruction::while_init() const {
2053 CHECK_EQ(HloOpcode::kWhile, opcode_);
2054 return operands_[0];
2055 }
2056
2057 HloComputation* HloInstruction::true_computation() const {
2058 CHECK_EQ(HloOpcode::kConditional, opcode_);
2059 CHECK_EQ(PRED, operand(0)->shape().element_type());
2060 return called_computations_[kTrueComputationIndex];
2061 }
2062
2063 HloComputation* HloInstruction::false_computation() const {
2064 CHECK_EQ(HloOpcode::kConditional, opcode_);
2065 CHECK_EQ(PRED, operand(0)->shape().element_type());
2066 return called_computations_[kFalseComputationIndex];
2067 }
2068
2069 const std::vector<HloComputation*>& HloInstruction::branch_computations()
2070 const {
2071 CHECK(HloOpcode::kConditional == opcode_);
2072 return called_computations_;
2073 }
2074
2075 int HloInstruction::branch_count() const {
2076 CHECK(HloOpcode::kConditional == opcode_);
2077 return called_computations_.size();
2078 }
2079
2080 HloComputation* HloInstruction::branch_computation(int b) const {
2081 CHECK(HloOpcode::kConditional == opcode_);
2082 CHECK_GE(b, 0);
2083 CHECK_LT(b, called_computations_.size());
2084 return called_computations_[b];
2085 }
2086
2087 void HloInstruction::set_branch_computation(int b,
2088 HloComputation* computation) {
2089 // Don't allow changing the computation for fused instructions so we don't
2090 // have to recompute called_instructions for the entire fusion instruction.
2091 CHECK(!IsFused());
2092 CHECK_EQ(HloOpcode::kConditional, opcode_);
2093 called_computations_[b] = computation;
2094 }
2095
2096 string HloInstruction::SignatureString() const {
2097 string operands =
2098 StrJoin(operands_, ", ", [](string* out, HloInstruction* operand) {
2099 StrAppend(out, ShapeUtil::HumanString(operand->shape()));
2100 });
2101 return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape()));
2102 }
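// For example, an add of two f32[4] operands producing an f32[4] result has
// the signature string "(f32[4], f32[4]) -> f32[4]".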
2103
2104 namespace {
2105
2106 string PrintName(const string& name, const HloPrintOptions& options) {
2107 return StrCat(options.print_percent() ? "%" : "", name);
2108 }
2109
2110 } // namespace
2111
2112 string HloInstruction::ToString(const HloPrintOptions& options) const {
2113 CanonicalNameMap new_map;
2114 return ToStringWithCanonicalNameMap(options, &new_map);
2115 }
2116
2117 bool HloInstruction::IsElementwiseImpl(
2118 const absl::optional<int64>& operand_idx) const {
2119 switch (opcode_) {
2120 // Unary elementwise operations.
2121 case HloOpcode::kAbs:
2122 case HloOpcode::kRoundNearestAfz:
2123 case HloOpcode::kCeil:
2124 case HloOpcode::kClz:
2125 case HloOpcode::kConvert:
2126 case HloOpcode::kBitcastConvert:
2127 case HloOpcode::kCopy:
2128 case HloOpcode::kCos:
2129 case HloOpcode::kExp:
2130 case HloOpcode::kExpm1:
2131 case HloOpcode::kFloor:
2132 case HloOpcode::kImag:
2133 case HloOpcode::kIsFinite:
2134 case HloOpcode::kLog:
2135 case HloOpcode::kLog1p:
2136 case HloOpcode::kNot:
2137 case HloOpcode::kNegate:
2138 case HloOpcode::kReal:
2139 case HloOpcode::kReducePrecision:
2140 case HloOpcode::kRsqrt:
2141 case HloOpcode::kSign:
2142 case HloOpcode::kSin:
2143 case HloOpcode::kSqrt:
2144 case HloOpcode::kTanh:
2145 CHECK_EQ(1, operand_count());
2146 return true;
2147
2148 // Binary elementwise operations, the same as in IsElementwiseBinary().
2149 case HloOpcode::kAdd:
2150 case HloOpcode::kAtan2:
2151 case HloOpcode::kCompare:
2152 case HloOpcode::kComplex:
2153 case HloOpcode::kDivide:
2154 case HloOpcode::kMaximum:
2155 case HloOpcode::kMinimum:
2156 case HloOpcode::kMultiply:
2157 case HloOpcode::kPower:
2158 case HloOpcode::kRemainder:
2159 case HloOpcode::kSubtract:
2160 case HloOpcode::kAnd:
2161 case HloOpcode::kOr:
2162 case HloOpcode::kXor:
2163 case HloOpcode::kShiftLeft:
2164 case HloOpcode::kShiftRightArithmetic:
2165 case HloOpcode::kShiftRightLogical:
2166 CHECK_EQ(2, operand_count());
2167 return true;
2168
2169 // Ternary elementwise operations.
2170 case HloOpcode::kSelect:
2171 case HloOpcode::kClamp:
2172 return true;
2173
2174 case HloOpcode::kDynamicUpdateSlice:
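// A dynamic-update-slice is elementwise only with respect to its first
// operand (the input being updated), not the update or the start indices.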
2175 return operand_idx.has_value() && operand_idx.value() == 0;
2176
2177 default:
2178 return false;
2179 }
2180 }
2181
2182 bool HloInstruction::IsCrossModuleAllReduce() const {
2183 return opcode() == HloOpcode::kAllReduce && all_reduce_id();
2184 }
2185
2186 bool HloInstruction::IsCrossReplicaAllReduce() const {
2187 return opcode() == HloOpcode::kAllReduce && !all_reduce_id();
2188 }
2189
2190 string HloInstruction::ToStringWithCanonicalNameMap(
2191 const HloPrintOptions& options,
2192 CanonicalNameMap* canonical_name_map) const {
2193 string result = "";
2194
2195 // Logic to print the instruction name (e.g. "%foo = ").
2196 if (options.canonicalize_instruction_names()) {
2197 if (options.is_in_nested_computation()) {
2198 // If we are canonicalizing instruction names and this is a top-level
2199 // HloInstruction::ToString() call, don't print an instruction name.
2200 StrAppend(&result,
2201 PrintName(canonical_name_map->LookupOrInsert(name()), options),
2202 " = ");
2203 }
2204 } else {
2205 StrAppend(&result, PrintName(name(), options), " = ");
2206 }
2207
2208 // Print opcode, operand(s) and shape.
2209 StrAppend(&result, ShapeUtil::HumanStringWithLayout(shape()), " ",
2210 HloOpcodeString(opcode()), "(",
2211 OperandsToStringWithCanonicalNameMap(options, canonical_name_map),
2212 ")");
2213
2214 // Print additional attributes. If an instruction contains a subcomputation,
2215 // the subcomputation is also printed here.
2216 for (const string& extra : ExtraAttributesToString(options)) {
2217 StrAppend(&result, ", ", extra);
2218 }
2219
2220 if (options.print_metadata() &&
2221 (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
2222 !metadata_.source_file().empty())) {
2223 StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}");
2224 }
2225 if (options.print_backend_config() && !backend_config_.empty()) {
2226 StrAppend(&result, ", backend_config=\"", CEscape(backend_config_), "\"");
2227 }
2228 return result;
2229 }
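// With default print options the result looks roughly like
//   %add.1 = f32[4]{0} add(f32[4]{0} %x, f32[4]{0} %y)
// with metadata and backend_config segments appended only when present.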
2230
2231 string HloInstruction::OperandsToString(const HloPrintOptions& options) const {
2232 CanonicalNameMap new_map;
2233 return OperandsToStringWithCanonicalNameMap(options, &new_map);
2234 }
2235
2236 string HloInstruction::OperandsToStringWithCanonicalNameMap(
2237 const HloPrintOptions& options,
2238 CanonicalNameMap* canonical_name_map) const {
2239 string operands;
2240 absl::Span<HloInstruction* const> slice(operands_);
2241 const int64 kMaxOperandsToShowIfCompact = 4;
2242 if (options.compact_operands() &&
2243 slice.size() > kMaxOperandsToShowIfCompact) {
2244 slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact);
2245 }
2246 operands = StrJoin(slice, ", ", [&](string* out, HloInstruction* operand) {
2247 // If the operand has already been deleted, print `null` in the string output.
2248 if (operand == nullptr) {
2249 StrAppend(out, "null ");
2250 return;
2251 }
2252 std::vector<string> str;
2253 if (options.print_operand_shape()) {
2254 str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape()));
2255 }
2256
2257 // In a top-level HloInstruction::ToString() call, the operand name is not
2258 // part of the canonical string.
2259 if (options.canonicalize_instruction_names() &&
2260 options.is_in_nested_computation()) {
2261 str.push_back(PrintName(
2262 canonical_name_map->LookupOrInsert(operand->name()), options));
2263 } else if (options.print_operand_names()) {
2264 str.push_back(PrintName(operand->name(), options));
2265 }
2266 StrAppend(out, StrJoin(str, " "));
2267 });
2268 const int64 remaining = operands_.size() - slice.size();
2269 if (slice.size() != operands_.size()) {
2270 StrAppend(&operands, ", ...(+", remaining, ")");
2271 }
2272 return operands;
2273 }
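// For example, with compact_operands() set and six operands, only the first
// four are printed, followed by ", ...(+2)" for the two elided operands.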
2274
2275 std::vector<string> HloInstruction::ExtraAttributesToString(
2276 const HloPrintOptions& options) const {
2277 std::vector<string> extra = ExtraAttributesToStringImpl(options);
2278
2279 if (options.print_subcomputation_mode() ==
2280 HloPrintOptions::PrintSubcomputationMode::kNameOnly) {
2281 if (opcode() == HloOpcode::kWhile) {
2282 extra.push_back(
2283 StrCat("condition=", PrintName(while_condition()->name(), options)));
2284 extra.push_back(
2285 StrCat("body=", PrintName(while_body()->name(), options)));
2286 } else if (opcode() == HloOpcode::kSelectAndScatter) {
2287 extra.push_back(StrCat("select=", PrintName(select()->name(), options)));
2288 extra.push_back(
2289 StrCat("scatter=", PrintName(scatter()->name(), options)));
2290 } else if (opcode() == HloOpcode::kConditional) {
2291 if (operand(0)->shape().element_type() == PRED) {
2292 extra.push_back(StrCat("true_computation=",
2293 PrintName(true_computation()->name(), options)));
2294 extra.push_back(
2295 StrCat("false_computation=",
2296 PrintName(false_computation()->name(), options)));
2297 } else {
2298 extra.push_back(StrCat(
2299 "branch_computations={",
2300 StrJoin(branch_computations(), ", ",
2301 [&](string* out, const HloComputation* computation) {
2302 StrAppend(out, PrintName(computation->name(), options));
2303 }),
2304 "}"));
2305 }
2306 } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
2307 opcode() == HloOpcode::kReduceWindow ||
2308 opcode() == HloOpcode::kReduce ||
2309 opcode() == HloOpcode::kAllReduce ||
2310 opcode() == HloOpcode::kScatter ||
2311 opcode() == HloOpcode::kSort) {
2312 extra.push_back(
2313 StrCat("to_apply=", PrintName(to_apply()->name(), options)));
2314 } else if (!called_computations().empty()) {
2315 extra.push_back(StrCat(
2316 "calls=",
2317 StrJoin(called_computations(), ", ",
2318 [&](string* out, const HloComputation* computation) {
2319 StrAppend(out, PrintName(computation->name(), options));
2320 })));
2321 }
2322 } else if (options.print_subcomputation_mode() ==
2323 HloPrintOptions::PrintSubcomputationMode::kFullBodies) {
2324 HloPrintOptions new_options = options;
2325 new_options.set_is_in_nested_computation(true);
2326 switch (opcode()) {
2327 case HloOpcode::kWhile:
2328 extra.push_back(
2329 StrCat("condition=\n", while_condition()->ToString(new_options)));
2330 extra.push_back(StrCat("body=\n", while_body()->ToString(new_options)));
2331 break;
2332 case HloOpcode::kSelectAndScatter:
2333 extra.push_back(StrCat("select=\n", select()->ToString(new_options)));
2334 extra.push_back(StrCat("scatter=\n", scatter()->ToString(new_options)));
2335 break;
2336 case HloOpcode::kConditional:
2337 if (operand(0)->shape().element_type() == PRED) {
2338 extra.push_back(StrCat("true_computation=\n",
2339 true_computation()->ToString(new_options)));
2340 extra.push_back(StrCat("false_computation=\n",
2341 false_computation()->ToString(new_options)));
2342 } else {
2343 extra.push_back(StrCat(
2344 "branch_computations={\n",
2345 StrJoin(branch_computations(), ",\n",
2346 [&](string* out, const HloComputation* computation) {
2347 StrAppend(out, computation->ToString(new_options));
2348 }),
2349 "\n}"));
2350 }
2351 break;
2352 case HloOpcode::kCall:
2353 case HloOpcode::kMap:
2354 case HloOpcode::kReduceWindow:
2355 case HloOpcode::kReduce:
2356 case HloOpcode::kAllReduce:
2357 case HloOpcode::kScatter:
2358 case HloOpcode::kSort:
2359 extra.push_back(
2360 StrCat("to_apply=\n", to_apply()->ToString(new_options)));
2361 break;
2362 default:
2363 if (!called_computations().empty()) {
2364 extra.push_back(StrCat(
2365 "calls=\n",
2366 StrJoin(called_computations(), ", ",
2367 [&](string* out, const HloComputation* computation) {
2368 StrAppend(out, computation->ToString(new_options));
2369 })));
2370 }
2371 break;
2372 }
2373 }
2374
2375 if (has_sharding()) {
2376 extra.push_back(StrCat("sharding=", sharding().ToString()));
2377 }
2378 if (options.print_control_dependencies() && !control_predecessors_.empty()) {
2379 extra.push_back(StrCat("control-predecessors={",
2380 StrJoin(control_predecessors_, ", ",
2381 [&](string* out, HloInstruction* pre) {
2382 StrAppend(out,
2383 PrintName(pre->name(), options));
2384 }),
2385 "}"));
2386 }
2387
2388 return extra;
2389 }
2390
2391 string HloInstruction::ToShortString() const {
2392 return StrCat("%", name(), " = ", HloOpcodeString(opcode()), "(",
2393 StrJoin(operands_, ", ",
2394 [](string* out, HloInstruction* operand) {
2395 StrAppend(out, "%", operand->name());
2396 }),
2397 ")");
2398 }
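// For example, an add named %add.3 with operands %x and %y prints as
// "%add.3 = add(%x, %y)".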
2399
2400 HloInstructionProto HloInstruction::ToProto() const {
2401 HloInstructionProto proto;
2402 CHECK(unique_id_ != -1)
2403 << "This instruction does not have a valid id. Please make sure the "
2404 "instruction is inside a module before dumping it.";
2405 proto.set_id(unique_id_);
2406 proto.set_name(name_);
2407 proto.set_opcode(HloOpcodeString(opcode_));
2408 *proto.mutable_shape() = shape_.ToProto();
2409 for (const HloInstruction* operand : operands_) {
2410 proto.add_operand_ids(operand->unique_id());
2411 }
2412 for (const HloInstruction* control : control_predecessors_) {
2413 proto.add_control_predecessor_ids(control->unique_id());
2414 }
2415
2416 *proto.mutable_metadata() = metadata_;
2417 proto.set_backend_config(backend_config_);
2418 if (opcode() != HloOpcode::kFusion) {
2419 for (const HloComputation* computation : called_computations_) {
2420 proto.add_called_computation_ids(computation->unique_id());
2421 }
2422 }
2423
2424 if (has_sharding()) {
2425 *proto.mutable_sharding() = sharding().ToProto();
2426 }
2427
2428 return proto;
2429 }
2430
2431 string HloInstruction::ToCategory() const {
2432 if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy ||
2433 opcode() == HloOpcode::kReshape) {
2434 return "data formatting";
2435 }
2436
2437 if (IsElementwise()) {
2438 return "non-fusion elementwise";
2439 }
2440
2441 return HloOpcodeString(opcode());
2442 }
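// For example, kTranspose, kCopy and kReshape map to "data formatting", an
// elementwise kAdd maps to "non-fusion elementwise", and anything else falls
// back to its opcode string, e.g. "dot" for kDot.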
2443
2444 HloInstruction* HloInstruction::tracing() const { return trace_instruction_; }
2445
2446 void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
2447 trace_instruction_ = trace_instruction;
2448 }
2449
2450 bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
2451
2452 bool HloInstruction::IsFusible() const {
2453 // Instructions which are traced should not be fused.
2454 if (tracing()) {
2455 return false;
2456 }
2457 // Some kinds of instructions don't make sense to fuse.
2458 switch (opcode_) {
2459 case HloOpcode::kDomain:
2460 case HloOpcode::kParameter:
2461 return false;
2462 // Side-effecting instructions cannot be fused.
2463 default:
2464 return !HasSideEffect();
2465 }
2466 }
2467
2468 HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape)
2469 : unique_id_(-1),
2470 opcode_(opcode),
2471 shape_(shape),
2472 name_(HloOpcodeString(opcode)) {
2473 TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_));
2474 }
2475
2476 template <typename HloInstructionPtr>
2477 Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
2478 switch (opcode_) {
2479 case HloOpcode::kAbs:
2480 return visitor->HandleAbs(this);
2481 case HloOpcode::kAtan2:
2482 return visitor->HandleAtan2(this);
2483 case HloOpcode::kRoundNearestAfz:
2484 return visitor->HandleRound(this);
2485 case HloOpcode::kBatchNormTraining:
2486 return visitor->HandleBatchNormTraining(this);
2487 case HloOpcode::kBatchNormInference:
2488 return visitor->HandleBatchNormInference(this);
2489 case HloOpcode::kBatchNormGrad:
2490 return visitor->HandleBatchNormGrad(this);
2491 case HloOpcode::kSign:
2492 return visitor->HandleSign(this);
2493 case HloOpcode::kConstant:
2494 return visitor->HandleConstant(this);
2495 case HloOpcode::kGetTupleElement:
2496 return visitor->HandleGetTupleElement(this);
2497 case HloOpcode::kParameter:
2498 return visitor->HandleParameter(this);
2499 case HloOpcode::kCompare:
2500 return visitor->HandleCompare(this);
2501 case HloOpcode::kComplex:
2502 return visitor->HandleComplex(this);
2503 case HloOpcode::kAdd:
2504 return visitor->HandleAdd(this);
2505 case HloOpcode::kDivide:
2506 return visitor->HandleDivide(this);
2507 case HloOpcode::kSubtract:
2508 return visitor->HandleSubtract(this);
2509 case HloOpcode::kMaximum:
2510 return visitor->HandleMaximum(this);
2511 case HloOpcode::kMinimum:
2512 return visitor->HandleMinimum(this);
2513 case HloOpcode::kAnd:
2514 return visitor->HandleAnd(this);
2515 case HloOpcode::kOr:
2516 return visitor->HandleOr(this);
2517 case HloOpcode::kXor:
2518 return visitor->HandleXor(this);
2519 case HloOpcode::kShiftLeft:
2520 return visitor->HandleShiftLeft(this);
2521 case HloOpcode::kShiftRightArithmetic:
2522 return visitor->HandleShiftRightArithmetic(this);
2523 case HloOpcode::kShiftRightLogical:
2524 return visitor->HandleShiftRightLogical(this);
2525 case HloOpcode::kConcatenate:
2526 return visitor->HandleConcatenate(this);
2527 case HloOpcode::kConvert:
2528 return visitor->HandleConvert(this);
2529 case HloOpcode::kBitcastConvert:
2530 return visitor->HandleBitcastConvert(this);
2531 case HloOpcode::kCopy:
2532 return visitor->HandleCopy(this);
2533 case HloOpcode::kMultiply:
2534 return visitor->HandleMultiply(this);
2535 case HloOpcode::kDot:
2536 return visitor->HandleDot(this);
2537 case HloOpcode::kPower:
2538 return visitor->HandlePower(this);
2539 case HloOpcode::kRemainder:
2540 return visitor->HandleRemainder(this);
2541 case HloOpcode::kSelect:
2542 return visitor->HandleSelect(this);
2543 case HloOpcode::kTupleSelect:
2544 return visitor->HandleTupleSelect(this);
2545 case HloOpcode::kConvolution:
2546 return visitor->HandleConvolution(this);
2547 case HloOpcode::kFft:
2548 return visitor->HandleFft(this);
2549 case HloOpcode::kAllReduce:
2550 return visitor->HandleAllReduce(this);
2551 case HloOpcode::kAllToAll:
2552 return visitor->HandleAllToAll(this);
2553 case HloOpcode::kCollectivePermute:
2554 return visitor->HandleCollectivePermute(this);
2555 case HloOpcode::kReplicaId:
2556 return visitor->HandleReplicaId(this);
2557 case HloOpcode::kTuple:
2558 return visitor->HandleTuple(this);
2559 case HloOpcode::kMap:
2560 return visitor->HandleMap(this);
2561 case HloOpcode::kClamp:
2562 return visitor->HandleClamp(this);
2563 case HloOpcode::kReduce:
2564 return visitor->HandleReduce(this);
2565 case HloOpcode::kReduceWindow:
2566 return visitor->HandleReduceWindow(this);
2567 case HloOpcode::kSelectAndScatter:
2568 return visitor->HandleSelectAndScatter(this);
2569 case HloOpcode::kNegate:
2570 return visitor->HandleNegate(this);
2571 case HloOpcode::kExp:
2572 return visitor->HandleExp(this);
2573 case HloOpcode::kExpm1:
2574 return visitor->HandleExpm1(this);
2575 case HloOpcode::kFloor:
2576 return visitor->HandleFloor(this);
2577 case HloOpcode::kCeil:
2578 return visitor->HandleCeil(this);
2579 case HloOpcode::kClz:
2580 return visitor->HandleClz(this);
2581 case HloOpcode::kLog:
2582 return visitor->HandleLog(this);
2583 case HloOpcode::kLog1p:
2584 return visitor->HandleLog1p(this);
2585 case HloOpcode::kTanh:
2586 return visitor->HandleTanh(this);
2587 case HloOpcode::kCos:
2588 return visitor->HandleCos(this);
2589 case HloOpcode::kSin:
2590 return visitor->HandleSin(this);
2591 case HloOpcode::kSqrt:
2592 return visitor->HandleSqrt(this);
2593 case HloOpcode::kRsqrt:
2594 return visitor->HandleRsqrt(this);
2595 case HloOpcode::kReal:
2596 return visitor->HandleReal(this);
2597 case HloOpcode::kImag:
2598 return visitor->HandleImag(this);
2599 case HloOpcode::kIsFinite:
2600 return visitor->HandleIsFinite(this);
2601 case HloOpcode::kNot:
2602 return visitor->HandleNot(this);
2603 case HloOpcode::kBitcast:
2604 return visitor->HandleBitcast(this);
2605 case HloOpcode::kBroadcast:
2606 return visitor->HandleBroadcast(this);
2607 case HloOpcode::kPad:
2608 return visitor->HandlePad(this);
2609 case HloOpcode::kReshape:
2610 return visitor->HandleReshape(this);
2611 case HloOpcode::kTranspose:
2612 return visitor->HandleTranspose(this);
2613 case HloOpcode::kReverse:
2614 return visitor->HandleReverse(this);
2615 case HloOpcode::kReducePrecision:
2616 return visitor->HandleReducePrecision(this);
2617 case HloOpcode::kSlice:
2618 return visitor->HandleSlice(this);
2619 case HloOpcode::kDynamicSlice:
2620 return visitor->HandleDynamicSlice(this);
2621 case HloOpcode::kDynamicUpdateSlice:
2622 return visitor->HandleDynamicUpdateSlice(this);
2623 case HloOpcode::kSort:
2624 return visitor->HandleSort(this);
2625 case HloOpcode::kInfeed:
2626 return visitor->HandleInfeed(this);
2627 case HloOpcode::kOutfeed:
2628 return visitor->HandleOutfeed(this);
2629 case HloOpcode::kRng:
2630 return visitor->HandleRng(this);
2631 case HloOpcode::kWhile:
2632 return visitor->HandleWhile(this);
2633 case HloOpcode::kFusion:
2634 return visitor->HandleFusion(this);
2635 case HloOpcode::kCall:
2636 return visitor->HandleCall(this);
2637 case HloOpcode::kConditional:
2638 return visitor->HandleConditional(this);
2639 case HloOpcode::kCustomCall:
2640 return visitor->HandleCustomCall(this);
2641 case HloOpcode::kRecv:
2642 return visitor->HandleRecv(this);
2643 case HloOpcode::kRecvDone:
2644 return visitor->HandleRecvDone(this);
2645 case HloOpcode::kSend:
2646 return visitor->HandleSend(this);
2647 case HloOpcode::kSendDone:
2648 return visitor->HandleSendDone(this);
2649 case HloOpcode::kGather:
2650 return visitor->HandleGather(this);
2651 case HloOpcode::kScatter:
2652 return visitor->HandleScatter(this);
2653 case HloOpcode::kDomain:
2654 return visitor->HandleDomain(this);
2655 case HloOpcode::kAfterAll:
2656 return visitor->HandleAfterAll(this);
2657 case HloOpcode::kAddDependency:
2658 return visitor->HandleAddDependency(this);
2659 case HloOpcode::kIota:
2660 return visitor->HandleIota(this);
2661 case HloOpcode::kGetDimensionSize:
2662 return visitor->HandleGetDimensionSize(this);
2663 case HloOpcode::kTriangularSolve:
2664 return visitor->HandleTriangularSolve(this);
2665 case HloOpcode::kCholesky:
2666 return visitor->HandleCholesky(this);
2667
2668 // This opcode is not handled here.
2669 case HloOpcode::kTrace:
2670 break;
2671 }
2672 return InternalError(
2673 "Unhandled HloOpcode for DfsHloVisitor: %s. This should not happen - "
2674 "please file a bug for XLA.",
2675 HloOpcodeString(opcode_));
2676 }
2677
2678 // Explicit instantiations.
2679 template Status HloInstruction::Visit(DfsHloVisitor* visitor);
2680 template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
2681
2682 using DFSStack = absl::InlinedVector<std::pair<int, HloInstruction*>, 16>;
2683
2684 // Push "child" onto the dfs_stack if not already visited. Returns false if a
2685 // cycle was detected, and true otherwise.
2686 template <typename Visitor>
2687 inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack,
2688 HloInstruction* child) {
2689 CHECK(child != nullptr);
2690 const int id = child->unique_id();
2691 CHECK_GE(id, 0) << "instruction may not have a parent computation";
2692 switch (visitor->GetVisitState(id)) {
2693 case Visitor::kVisiting:
2694 return false;
2695
2696 case Visitor::kVisited:
2697 // Nothing to do
2698 return true;
2699
2700 case Visitor::kNotVisited:
2701 dfs_stack->push_back(std::make_pair(id, child));
2702 return true;
2703 }
2704 }
2705
2706 using InternalCompareFunction =
2707 std::function<bool(std::pair<int, const HloInstruction*>,
2708 std::pair<int, const HloInstruction*>)>;
2709 template <typename Visitor>
2710 static Status PostOrderDFS(HloInstruction* root, Visitor* visitor,
2711 const InternalCompareFunction* operand_order,
2712 bool ignore_control_predecessors) {
2713 visitor->ReserveVisitStates(root->GetModule()->instruction_count());
2714
2715 // dfs_stack holds pairs of <HloInstruction*->unique_id(), HloInstruction*>.
2716 //
2717 // We need to keep track of both the id and the instruction because
2718 // instructions can get deleted while they are on the stack, so we
2719 // can't always use the (potentially dead) instruction object to grab
2720 // its id.
2721 DFSStack dfs_stack;
2722 dfs_stack.emplace_back(root->unique_id(), root);
2723
2724 do {
2725 DCHECK(!dfs_stack.empty());
2726
2727 int current_id = dfs_stack.back().first;
2728 HloInstruction* current_node = dfs_stack.back().second;
2729 CHECK_GE(current_id, 0) << current_id << ": " << current_node
2730 << ": instruction may not have parent computation";
2731 typename Visitor::VisitState visit_state =
2732 visitor->GetVisitState(current_id);
2733 if (visit_state == Visitor::kVisited) {
2734 dfs_stack.pop_back();
2735 VLOG(3) << "Not visiting HLO %" << current_node->name()
2736 << " as it was already visited.";
2737 continue;
2738 }
2739
2740 if (visit_state == Visitor::kVisiting) {
2741 dfs_stack.pop_back();
2742
2743 TF_RETURN_IF_ERROR(visitor->Preprocess(current_node));
2744 VLOG(2) << "Visiting HLO %" << current_node->name();
2745 TF_RETURN_IF_ERROR(current_node->Visit(visitor));
2746 visitor->SetVisitState(current_id, Visitor::kVisited);
2747 TF_RETURN_IF_ERROR(visitor->Postprocess(current_node));
2748 continue;
2749 }
2750
2751 visitor->SetVisitState(current_id, Visitor::kVisiting);
2752
2753 const size_t old_dfs_stack_size = dfs_stack.size();
2754 for (HloInstruction* child : current_node->operands()) {
2755 if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
2756 return FailedPrecondition(
2757             "A cycle was detected while visiting instruction %s",
2758 current_node->ToString());
2759 }
2760 }
2761
2762 if (!ignore_control_predecessors) {
2763 for (HloInstruction* child : current_node->control_predecessors()) {
2764 if (!TF_PREDICT_TRUE(PushDFSChild(visitor, &dfs_stack, child))) {
2765 return FailedPrecondition(
2766               "A cycle was detected while visiting instruction %s",
2767 current_node->ToString());
2768 }
2769 }
2770 }
2771
2772 if (operand_order != nullptr) {
2773 std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(),
2774 *operand_order);
2775 }
2776
2777 // This makes the traversal order the same as what you'd expect
2778 // out of a recursive algorithm.
2779 std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end());
2780 } while (!dfs_stack.empty());
2781
2782 return Status::OK();
2783 }
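// Illustrative walk-through (hypothetical graph, not part of the build): for
//
//   add = add(a, b)
//   mul = multiply(add, c)
//
// PostOrderDFS(mul, visitor, /*operand_order=*/nullptr,
//              /*ignore_control_predecessors=*/false) visits operands before
// their users, so the visitor sees a, b, add, c, mul -- the same order a
// recursive post-order traversal starting at mul would produce, thanks to the
// std::reverse above.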
2784
2785 template <typename HloInstructionPtr>
2786 Status HloInstruction::Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
2787 bool call_finish_visit,
2788 bool ignore_control_predecessors) {
2789 VLOG(3) << "HloInstruction::Accept(%" << name() << ")";
2790 TF_RETURN_IF_ERROR(
2791 PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors));
2792 if (call_finish_visit) {
2793 TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
2794 }
2795 return Status::OK();
2796 }
2797
2798 // Explicit instantiations.
2799 template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
2800 template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
2801
2802 Status HloInstruction::AcceptWithOperandOrder(
2803 DfsHloVisitor* visitor, const CompareFunction& operand_order,
2804 bool call_finish_visit) {
2805 VLOG(2) << "HloInstruction::AcceptWithOperandOrder(%" << name() << ")";
2806 InternalCompareFunction func = [&operand_order](
2807 std::pair<int, const HloInstruction*> a,
2808 std::pair<int, const HloInstruction*> b) {
2809 // Call the client's comparison function on the actual HloInstruction*
2810 // objects (ignoring the internal ids we also have in our stack entries)
2811 return operand_order(a.second, b.second);
2812 };
2813 TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &func,
2814 /*ignore_control_predecessors=*/false));
2815 if (call_finish_visit) {
2816 VLOG(3) << "HloInstruction::AcceptWithOperandOrder BEFORE FINISH VISIT";
2817 TF_RETURN_IF_ERROR(visitor->FinishVisit(this));
2818 VLOG(3) << "HloInstruction::AcceptWithOperandOrder AFTER FINISH VISIT";
2819 }
2820 VLOG(2) << "HloInstruction::AcceptWithOperandOrder EXIT";
2821 return Status::OK();
2822 }
2823
2824 Status HloInstruction::Accept(
2825 const std::function<Status(HloInstruction*)>& visitor_func) {
2826 FunctionVisitor visitor(visitor_func);
2827 return this->Accept(&visitor);
2828 }
2829
2830 Status HloInstruction::Accept(
2831 const std::function<Status(const HloInstruction*)>& visitor_func) const {
2832 ConstFunctionVisitor visitor(visitor_func);
2833 return this->Accept(&visitor);
2834 }
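// Usage sketch (illustrative only): the FunctionVisitor overloads above let
// callers traverse in post order with a lambda instead of a full visitor
// subclass, e.g. counting the instructions reachable from "root":
//
//   int64 count = 0;
//   TF_RETURN_IF_ERROR(root->Accept([&count](HloInstruction* hlo) {
//     ++count;
//     return Status::OK();
//   }));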
2835
2836 const Shape& HloInstruction::shape() const { return shape_; }
2837
2838 std::vector<int64> HloInstruction::OperandIndices(
2839 const HloInstruction* operand) const {
2840 std::vector<int64> result;
2841 for (int64 i = 0; i < operand_count(); ++i) {
2842 if (this->operand(i) == operand) {
2843 result.push_back(i);
2844 }
2845 }
2846 return result;
2847 }
2848
2849 bool HloInstruction::IsElementwiseBinary() const {
2850 return IsElementwise() && operand_count() == 2;
2851 }
2852
2853 bool HloInstruction::IsElementwise() const {
2854 return IsElementwiseImpl(absl::nullopt);
2855 }
2856
2857 bool HloInstruction::IsElementwiseOnOperand(int64 operand_idx) const {
2858 return IsElementwiseImpl(operand_idx);
2859 }
2860
2861 // A helper class for the memoized, recursive computation behind the
2862 // HloOpcode::kFusion case of HloInstruction::OperandElementUse below.
2863 class HloInstruction::FusionReusesParamElements {
2864 public:
2865 using UseKind = HloInstruction::UseKind;
2866
2867   // We could instead iterate backwards through fused_instructions_ here,
2868   // since it is in reverse postorder, and compute whether each fused
2869   // instruction reuses the value of this parameter. That would save stack
2870   // space but would not let us finish early once a reuse is found.
2871   static UseKind Compute(int64 i, const HloInstruction& hlo) {
2872 absl::flat_hash_map<const HloInstruction*, UseKind> memoization_cache;
2873 return ComputeInternal(i, hlo, &memoization_cache);
2874 }
2875
2876 private:
2877   static UseKind ComputeInternal(
2878 int64 i, const HloInstruction& hlo,
2879 absl::flat_hash_map<const HloInstruction*, UseKind>* cache) {
2880 if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) {
2881 if (hlo_param->parameter_number() == i) {
2882 return UseKind::kUse;
2883 }
2884 }
2885
2886 auto p = cache->emplace(&hlo, UseKind{});
2887 auto value_it = p.first;
2888 const bool key_is_new = p.second;
2889
2890 if (key_is_new) {
2891 for (int64 j = 0; j < hlo.operands_.size(); ++j) {
2892 UseKind old_val = value_it->second;
2893
2894         // The recursive call below may grow the cache, invalidating value_it.
2895 UseKind new_val =
2896 Plus(old_val, std::min(hlo.OperandElementUse(j),
2897 ComputeInternal(i, *hlo.operand(j), cache)));
2898
2899 // Re-acquire the iterator. We could work harder to do this only if
2900 // absolutely necessary, but this code is not hot enough to warrant
2901 // that.
2902 value_it = cache->find(&hlo);
2903 value_it->second = new_val;
2904 }
2905 }
2906 return value_it->second;
2907 }
2908
2909 // Fold operation for UseKinds.
2910   static UseKind Plus(UseKind a, UseKind b) {
2911 if (a == UseKind::kNoUse) {
2912 return b;
2913 } else if (b == UseKind::kNoUse) {
2914 return a;
2915 } else if (a == UseKind::kReuse || b == UseKind::kReuse) {
2916 return UseKind::kReuse;
2917 } else if (a == UseKind::kUsePermutingElements ||
2918 b == UseKind::kUsePermutingElements) {
2919 return UseKind::kReuse;
2920 } else {
2921 CHECK(a == UseKind::kUse && b == UseKind::kUse);
2922 return UseKind::kUse;
2923 }
2924 }
2925 };
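// Worked examples for the Plus fold above (each line follows directly from the
// branches of Plus):
//
//   Plus(kNoUse, kUse)                == kUse
//   Plus(kUse, kUse)                  == kUse
//   Plus(kUse, kUsePermutingElements) == kReuse  // two separate uses add up
//   Plus(kReuse, <anything>)          == kReuse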
2926
2927 HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const {
2928 switch (opcode_) {
2929 case HloOpcode::kBitcast:
2930 case HloOpcode::kConcatenate:
2931 case HloOpcode::kReshape:
2932 case HloOpcode::kReverse:
2933 case HloOpcode::kSlice:
2934 case HloOpcode::kTranspose:
2935 return UseKind::kUsePermutingElements;
2936 case HloOpcode::kPad:
2937 // Pad reuses the padding value but not the padded array elements.
2938 return i > 0 ? UseKind::kReuse : UseKind::kUsePermutingElements;
2939 case HloOpcode::kReduce:
2940 // Reduce reuses the init values but not the operand array elements.
2941 return i >= Cast<HloReduceInstruction>(this)->input_count()
2942 ? UseKind::kReuse
2943 : UseKind::kUsePermutingElements;
2944 case HloOpcode::kFusion:
2945 // Uses the memoizing, recursive computation defined above.
2946 return FusionReusesParamElements::Compute(i, *fused_expression_root());
2947 case HloOpcode::kDot:
2948 // Dot operations with inputs [A,B] * [B,1] do not re-use
2949 // elements on their left operand.
2950 // Dot operations with inputs [1,A] * [A,B] do not re-use
2951 // elements on their right operand.
2952 if (shape().dimensions_size() == 2) {
2953 if ((i == 0 && shape().dimensions(1) == 1) ||
2954 (i == 1 && shape().dimensions(0) == 1)) {
2955 return UseKind::kUse;
2956 }
2957 }
2958 return UseKind::kReuse;
2959 case HloOpcode::kDynamicUpdateSlice:
2960 // Dynamic-update-slice reuses only start_indices.
2961 if (i == 0 || i == 1) {
2962 return UseKind::kUse;
2963 }
2964 return UseKind::kReuse;
2965 default:
2966 return IsElementwise() ? UseKind::kUse : UseKind::kReuse;
2967 }
2968 }
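// Examples for the kDot case above (shapes are illustrative):
//
//   f32[M,K] dot f32[K,1] -> operand 0 is kUse (each lhs element is read
//                            once), operand 1 is kReuse (the rhs vector is
//                            read for every output row).
//   f32[1,K] dot f32[K,N] -> operand 1 is kUse, operand 0 is kReuse.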
2969
2970 std::tuple<bool, std::vector<int64>, std::vector<int64>>
2971 HloInstruction::ReshapeMerelyInsertsOrDeletes1SizedDimensions() const {
2972 if (HloOpcode::kReshape != opcode_) {
2973 return std::make_tuple(false, std::vector<int64>(), std::vector<int64>());
2974 }
2975 return ShapeUtil::InsertedOrDeleted1SizedDimensions(operand(0)->shape_,
2976 shape_);
2977 }
2978
2979 string ToString(HloInstruction::FusionKind kind) {
2980 switch (kind) {
2981 case HloInstruction::FusionKind::kLoop:
2982 return "kLoop";
2983 case HloInstruction::FusionKind::kInput:
2984 return "kInput";
2985 case HloInstruction::FusionKind::kOutput:
2986 return "kOutput";
2987 case HloInstruction::FusionKind::kCustom:
2988 return "kCustom";
2989 }
2990 }
2991
2992 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
2993 const string& kind_name) {
2994 if (kind_name == "kLoop") {
2995 return HloInstruction::FusionKind::kLoop;
2996 }
2997 if (kind_name == "kInput") {
2998 return HloInstruction::FusionKind::kInput;
2999 }
3000 if (kind_name == "kOutput") {
3001 return HloInstruction::FusionKind::kOutput;
3002 }
3003 if (kind_name == "kCustom") {
3004 return HloInstruction::FusionKind::kCustom;
3005 }
3006 return InvalidArgument("Unknown fusion kind: %s", kind_name);
3007 }
3008
3009 string PaddingConfigToString(const PaddingConfig& padding) {
3010 bool has_interior_padding =
3011 absl::c_any_of(padding.dimensions(),
3012 [](const PaddingConfig::PaddingConfigDimension& dim) {
3013 return dim.interior_padding() != 0;
3014 });
3015 return StrJoin(
3016 padding.dimensions(), "x",
3017 [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
3018 StrAppend(
3019 out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
3020 has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
3021 });
3022 }
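// Example output (illustrative): a 2-D PaddingConfig with dimension 0 padded
// by {edge_low=1, edge_high=2, interior=0} and dimension 1 padded by
// {edge_low=0, edge_high=0, interior=3} prints as "1_2_0x0_0_3"; when no
// dimension has interior padding, the interior field is dropped entirely,
// e.g. "1_2x0_0".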
3023
3024 string OpMetadataToString(const OpMetadata& metadata) {
3025 std::vector<string> result;
3026 if (!metadata.op_type().empty()) {
3027 result.push_back(StrCat("op_type=\"", CEscape(metadata.op_type()), "\""));
3028 }
3029 if (!metadata.op_name().empty()) {
3030 result.push_back(StrCat("op_name=\"", CEscape(metadata.op_name()), "\""));
3031 }
3032 if (!metadata.source_file().empty()) {
3033 result.push_back(
3034 StrCat("source_file=\"", CEscape(metadata.source_file()), "\""));
3035 }
3036 if (metadata.source_line() != 0) {
3037 result.push_back(StrCat("source_line=", metadata.source_line()));
3038 }
3039 return StrJoin(result, " ");
3040 }
3041
3042 string RandomDistributionToString(const RandomDistribution& distribution) {
3043 return absl::AsciiStrToLower(RandomDistribution_Name(distribution));
3044 }
3045
3046 string PrecisionToString(const PrecisionConfig::Precision& precision) {
3047 return absl::AsciiStrToLower(PrecisionConfig::Precision_Name(precision));
3048 }
3049
3050 string ConvolutionDimensionNumbersToString(
3051 const ConvolutionDimensionNumbers& dnums) {
3052 // lhs_dims[i] is the symbol of the logical dimension i for the lhs
3053 // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b".
3054 std::vector<string> lhs_dims(2 + dnums.input_spatial_dimensions().size());
3055 lhs_dims[dnums.input_batch_dimension()] = 'b';
3056 lhs_dims[dnums.input_feature_dimension()] = 'f';
3057 for (int64 i = 0; i < dnums.input_spatial_dimensions().size(); ++i) {
3058 lhs_dims[dnums.input_spatial_dimensions(i)] = StrCat(i);
3059 }
3060
3061 std::vector<string> rhs_dims(2 + dnums.kernel_spatial_dimensions().size());
3062 rhs_dims[dnums.kernel_input_feature_dimension()] = "i";
3063 rhs_dims[dnums.kernel_output_feature_dimension()] = "o";
3064 for (int64 i = 0; i < dnums.kernel_spatial_dimensions().size(); ++i) {
3065 rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i);
3066 }
3067
3068 std::vector<string> output_dims(2 + dnums.output_spatial_dimensions().size());
3069 output_dims[dnums.output_batch_dimension()] = 'b';
3070 output_dims[dnums.output_feature_dimension()] = 'f';
3071 for (int64 i = 0; i < dnums.output_spatial_dimensions().size(); ++i) {
3072 output_dims[dnums.output_spatial_dimensions(i)] = StrCat(i);
3073 }
3074
3075 return StrCat(StrJoin(lhs_dims, ""), "_", StrJoin(rhs_dims, ""), "->",
3076 StrJoin(output_dims, ""));
3077 }
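// Example (illustrative): for an NHWC input (batch dimension 0, feature
// dimension 3, spatial dimensions {1, 2}), an HWIO kernel (spatial {0, 1},
// input feature 2, output feature 3), and an NHWC output, this returns
// "b01f_01io->b01f".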
3078
3079 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) {
3080 static std::unordered_map<string, RandomDistribution>* map = [] {
3081 static auto* map = new std::unordered_map<string, RandomDistribution>;
3082 for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) {
3083 if (RandomDistribution_IsValid(i)) {
3084 auto value = static_cast<RandomDistribution>(i);
3085 (*map)[RandomDistributionToString(value)] = value;
3086 }
3087 }
3088 return map;
3089 }();
3090 auto found = map->find(absl::AsciiStrToLower(name));
3091 if (found == map->end()) {
3092     return InvalidArgument("Unknown distribution: %s", name);
3093 }
3094 return found->second;
3095 }
3096
3097 StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name) {
3098 static std::unordered_map<string, PrecisionConfig::Precision>* map = [] {
3099 static auto* map =
3100 new std::unordered_map<string, PrecisionConfig::Precision>;
3101 for (int i = 0; i < PrecisionConfig::Precision_ARRAYSIZE; i++) {
3102 if (PrecisionConfig::Precision_IsValid(i)) {
3103 auto value = static_cast<PrecisionConfig::Precision>(i);
3104 (*map)[PrecisionToString(value)] = value;
3105 }
3106 }
3107 return map;
3108 }();
3109 auto found = map->find(absl::AsciiStrToLower(name));
3110 if (found == map->end()) {
3111     return InvalidArgument("Unknown precision: %s", name);
3112 }
3113 return found->second;
3114 }
3115
3116 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
3117 return os << ToString(kind);
3118 }
3119
3120 bool HloPtrComparator::operator()(const HloInstruction* const& lhs,
3121 const HloInstruction* const& rhs) const {
3122 if (rhs == nullptr) {
3123 // Nothing compares less than nullptr.
3124 return false;
3125 }
3126 if (lhs == nullptr) {
3127 return true;
3128 }
3129 auto lhs_module = lhs->GetModule();
3130 auto rhs_module = rhs->GetModule();
3131 CHECK((lhs_module == nullptr && rhs_module == nullptr) ||
3132 (lhs_module != nullptr && rhs_module != nullptr));
3133 if (lhs_module != nullptr &&
3134 lhs_module->unique_id() != rhs_module->unique_id()) {
3135 return lhs_module->unique_id() < rhs_module->unique_id();
3136 }
3137 return lhs->unique_id() < rhs->unique_id();
3138 }
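// Usage sketch (illustrative): because HloPtrComparator orders by
// (module unique_id, instruction unique_id) rather than by raw pointer value,
// ordered containers keyed on instruction pointers iterate deterministically
// across runs, e.g.
//
//   std::map<HloInstruction*, int64, HloPtrComparator> use_counts;
//   std::set<HloInstruction*, HloPtrComparator> worklist;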
3139
3140 bool HloInstruction::CouldBeBitcast() const {
3141 switch (opcode_) {
3142 case HloOpcode::kTranspose:
3143 return true;
3144 case HloOpcode::kReshape:
3145 return std::get<0>(ReshapeMerelyInsertsOrDeletes1SizedDimensions());
3146 default:
3147 return false;
3148 }
3149 }
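// Examples (illustrative): a reshape from f32[4,1,5] to f32[4,5] only deletes
// a size-1 dimension, so CouldBeBitcast() returns true; a reshape from
// f32[4,5] to f32[5,4] rearranges elements and returns false.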
3150
3151 Status HloInstruction::GetBackendConfigInternal(
3152 tensorflow::protobuf::Message* proto) const {
3153 proto->Clear();
3154
3155 // Empty string does not parse as valid JSON, but it's a valid backend config,
3156 // corresponding to the empty proto.
3157 if (backend_config_.empty()) {
3158 return Status::OK();
3159 }
3160 return tensorflow::HumanReadableJsonToProto(backend_config_, proto);
3161 }
3162
3163 Status HloInstruction::set_backend_config(
3164 const tensorflow::protobuf::Message& proto) {
3165 TF_ASSIGN_OR_RETURN(backend_config_, BackendConfigToRawString(proto));
3166 return Status::OK();
3167 }
3168
3169 /* static */ StatusOr<string> HloInstruction::BackendConfigToRawString(
3170 const tensorflow::protobuf::Message& proto) {
3171 string ret;
3172 TF_RETURN_IF_ERROR(tensorflow::ProtoToHumanReadableJson(proto, &ret));
3173 return ret;
3174 }
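// Round-trip sketch (illustrative; MyBackendConfig is a hypothetical proto
// message, and the backend_config<T>() accessor is assumed from the header):
// backend_config_ stores the human-readable JSON form of a backend-specific
// proto, so writes and reads go through the JSON conversions above:
//
//   MyBackendConfig cfg;
//   cfg.set_foo(42);  // hypothetical field
//   TF_RETURN_IF_ERROR(instr->set_backend_config(cfg));
//   TF_ASSIGN_OR_RETURN(MyBackendConfig parsed,
//                       instr->backend_config<MyBackendConfig>());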
3175
3176 const PrecisionConfig& HloInstruction::precision_config() const {
3177 if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
3178 return convolution->precision_config();
3179 }
3180 if (auto* dot = DynCast<HloDotInstruction>(this)) {
3181 return dot->precision_config();
3182 }
3183 LOG(FATAL) << "Unimplemented method.";
3184 }
3185
3186 PrecisionConfig* HloInstruction::mutable_precision_config() {
3187 if (auto* convolution = DynCast<HloConvolutionInstruction>(this)) {
3188 return convolution->mutable_precision_config();
3189 }
3190 if (auto* dot = DynCast<HloDotInstruction>(this)) {
3191 return dot->mutable_precision_config();
3192 }
3193 LOG(FATAL) << "Unimplemented method.";
3194 }
3195
3196 HloModule* HloInstruction::GetModule() const {
3197 if (parent_) {
3198 return parent_->parent();
3199 }
3200 return nullptr;
3201 }
3202
3203 void HloInstruction::UniquifyName(NameUniquer* name_uniquer) {
3204 string parent_str = parent() == nullptr ? "noparent" : parent()->name();
3205 name_ = name_uniquer->GetUniqueName(name_);
3206 }
3207
3208 void HloInstruction::set_outer_dimension_partitions(
3209 const std::vector<int64>& outer_dimension_partitions) {
3210 outer_dimension_partitions_ = outer_dimension_partitions;
3211 }
3212
3213 // TODO(b/80131774): Remove these temporary methods after transition.
3214 int64 HloInstruction::feature_index() const {
3215 return Cast<HloBatchNormInstruction>(this)->feature_index();
3216 }
3217
3218 float HloInstruction::epsilon() const {
3219 return Cast<HloBatchNormInstruction>(this)->epsilon();
3220 }
3221
3222 FftType HloInstruction::fft_type() const {
3223 return Cast<HloFftInstruction>(this)->fft_type();
3224 }
3225
3226 const std::vector<int64>& HloInstruction::fft_length() const {
3227 return Cast<HloFftInstruction>(this)->fft_length();
3228 }
3229
3230 int64 HloInstruction::channel_id() const {
3231 return Cast<HloSendRecvInstruction>(this)->channel_id();
3232 }
3233
3234 int64 HloInstruction::concatenate_dimension() const {
3235 return Cast<HloConcatenateInstruction>(this)->concatenate_dimension();
3236 }
3237
3238 int64 HloInstruction::dimension() const {
3239 return Cast<HloGetDimensionSizeInstruction>(this)->dimension();
3240 }
3241
3242 bool HloInstruction::IsRank2Transpose() const {
3243 auto transpose = DynCast<HloTransposeInstruction>(this);
3244 return transpose != nullptr && transpose->IsRank2Transpose();
3245 }
3246
3247 int64 HloInstruction::slice_starts(int64 dimension) const {
3248 return Cast<HloSliceInstruction>(this)->slice_starts(dimension);
3249 }
3250
3251 const std::vector<int64>& HloInstruction::slice_starts() const {
3252 return Cast<HloSliceInstruction>(this)->slice_starts();
3253 }
3254
3255 int64 HloInstruction::slice_limits(int64 dimension) const {
3256 return Cast<HloSliceInstruction>(this)->slice_limits(dimension);
3257 }
3258
3259 const std::vector<int64>& HloInstruction::slice_limits() const {
3260 return Cast<HloSliceInstruction>(this)->slice_limits();
3261 }
3262
3263 int64 HloInstruction::slice_strides(int64 dimension) const {
3264 return Cast<HloSliceInstruction>(this)->slice_strides(dimension);
3265 }
3266
3267 const std::vector<int64>& HloInstruction::slice_strides() const {
3268 return Cast<HloSliceInstruction>(this)->slice_strides();
3269 }
3270
3271 const Literal& HloInstruction::literal() const {
3272 return Cast<HloConstantInstruction>(this)->literal();
3273 }
3274
3275 bool HloInstruction::IsConstant() const {
3276 return DynCast<HloConstantInstruction>(this) != nullptr;
3277 }
3278
3279 void HloInstruction::RelayoutConstant(const Layout& new_layout,
3280 const ShapeIndex& shape_index) {
3281 Cast<HloConstantInstruction>(this)->RelayoutConstant(new_layout, shape_index);
3282 }
3283
3284 string HloInstruction::TracingTag() const {
3285 return Cast<HloTraceInstruction>(this)->TracingTag();
3286 }
3287
3288 HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) {
3289 return Cast<HloFusionInstruction>(this)->AddFusionOperand(new_operand);
3290 }
3291
3292 // Delegates to HloFusionInstruction::MergeFusionInstruction.
3293 void HloInstruction::MergeFusionInstruction(
3294 HloInstruction* instruction_to_merge) {
3295 return Cast<HloFusionInstruction>(this)->MergeFusionInstruction(
3296 Cast<HloFusionInstruction>(instruction_to_merge));
3297 }
3298
3299 // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
3300 void HloInstruction::MergeFusionInstructionIntoMultiOutput(
3301 HloInstruction* instruction_to_merge) {
3302 return Cast<HloFusionInstruction>(this)
3303 ->MergeFusionInstructionIntoMultiOutput(
3304 Cast<HloFusionInstruction>(instruction_to_merge));
3305 }
3306
3307 HloInstruction* HloInstruction::FuseInstruction(
3308 HloInstruction* instruction_to_fuse) {
3309 return Cast<HloFusionInstruction>(this)->FuseInstruction(instruction_to_fuse);
3310 }
3311
3312 HloInstruction* HloInstruction::FuseInstructionIntoMultiOutput(
3313 HloInstruction* instruction_to_fuse) {
3314 return Cast<HloFusionInstruction>(this)->FuseInstructionIntoMultiOutput(
3315 instruction_to_fuse);
3316 }
3317
3318 HloComputation* HloInstruction::fused_instructions_computation() const {
3319 return Cast<HloFusionInstruction>(this)->fused_instructions_computation();
3320 }
3321
3322 HloInstruction* HloInstruction::fused_expression_root() const {
3323 return Cast<HloFusionInstruction>(this)->fused_expression_root();
3324 }
3325
3326 const tensorflow::gtl::iterator_range<UnwrappingIterator<
3327 std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
3328 HloInstruction::fused_instructions() const {
3329 return Cast<HloFusionInstruction>(this)->fused_instructions();
3330 }
3331
3332 const tensorflow::gtl::iterator_range<
3333 UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
3334 HloInstruction::fused_instructions() {
3335 return Cast<HloFusionInstruction>(this)->fused_instructions();
3336 }
3337
3338 int64 HloInstruction::fused_instruction_count() const {
3339 return Cast<HloFusionInstruction>(this)->fused_instruction_count();
3340 }
3341
3342 HloInstruction* HloInstruction::fused_parameter(int64 parameter_number) const {
3343 return Cast<HloFusionInstruction>(this)->fused_parameter(parameter_number);
3344 }
3345
3346 const std::vector<HloInstruction*>& HloInstruction::fused_parameters() const {
3347 return Cast<HloFusionInstruction>(this)->fused_parameters();
3348 }
3349
3350 const bool HloInstruction::IsMultiOutputFusion() const {
3351 const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(this);
3352 return fusion != nullptr && fusion->IsMultiOutputFusion();
3353 }
3354
3355 HloInstruction::FusionKind HloInstruction::fusion_kind() const {
3356 return Cast<HloFusionInstruction>(this)->fusion_kind();
3357 }
3358
3359 void HloInstruction::set_fusion_kind(FusionKind kind) {
3360 return Cast<HloFusionInstruction>(this)->set_fusion_kind(kind);
3361 }
3362
3363 RandomDistribution HloInstruction::random_distribution() const {
3364 return Cast<HloRngInstruction>(this)->random_distribution();
3365 }
3366
3367 int64 HloInstruction::parameter_number() const {
3368 return Cast<HloParameterInstruction>(this)->parameter_number();
3369 }
3370
3371 void HloInstruction::set_parameter_replicated_at_leaf_buffers(
3372 absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
3373 return Cast<HloParameterInstruction>(this)
3374 ->set_parameter_replicated_at_leaf_buffers(
3375 parameter_replicated_at_leaf_buffers);
3376 }
3377
3378 const absl::optional<std::vector<bool>>&
3379 HloInstruction::parameter_replicated_at_leaf_buffers() const {
3380 return Cast<HloParameterInstruction>(this)
3381 ->parameter_replicated_at_leaf_buffers();
3382 }
3383
3384 int64 HloInstruction::tuple_index() const {
3385 return Cast<HloGetTupleElementInstruction>(this)->tuple_index();
3386 }
3387
3388 int32 HloInstruction::exponent_bits() const {
3389 return Cast<HloReducePrecisionInstruction>(this)->exponent_bits();
3390 }
3391
3392 int32 HloInstruction::mantissa_bits() const {
3393 return Cast<HloReducePrecisionInstruction>(this)->mantissa_bits();
3394 }
3395
3396 string HloInstruction::infeed_config() const {
3397 return Cast<HloInfeedInstruction>(this)->infeed_config();
3398 }
3399
3400 void HloInstruction::set_infeed_config(const string& config) {
3401 return Cast<HloInfeedInstruction>(this)->set_infeed_config(config);
3402 }
3403
3404 const Shape& HloInstruction::outfeed_shape() const {
3405 return Cast<HloOutfeedInstruction>(this)->outfeed_shape();
3406 }
3407
3408 const string& HloInstruction::outfeed_config() const {
3409 return Cast<HloOutfeedInstruction>(this)->outfeed_config();
3410 }
3411
3412 const std::vector<ReplicaGroup>& HloInstruction::replica_groups() const {
3413 return Cast<HloCollectiveInstruction>(this)->replica_groups();
3414 }
3415
3416 const std::vector<std::pair<int64, int64>>&
3417 HloInstruction::source_target_pairs() const {
3418 return Cast<HloCollectivePermuteInstruction>(this)->source_target_pairs();
3419 }
3420
3421 string HloInstruction::all_reduce_barrier() const {
3422 return Cast<HloAllReduceInstruction>(this)->all_reduce_barrier();
3423 }
3424
3425 void HloInstruction::set_all_reduce_barrier(const string& barrier) {
3426 return Cast<HloAllReduceInstruction>(this)->set_all_reduce_barrier(barrier);
3427 }
3428
3429 absl::optional<int64> HloInstruction::all_reduce_id() const {
3430 return Cast<HloAllReduceInstruction>(this)->all_reduce_id();
3431 }
3432
3433 void HloInstruction::set_all_reduce_id(
3434 const absl::optional<int64>& all_reduce_id) {
3435 return Cast<HloAllReduceInstruction>(this)->set_all_reduce_id(all_reduce_id);
3436 }
3437
3438 const ConvolutionDimensionNumbers&
3439 HloInstruction::convolution_dimension_numbers() const {
3440 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
3441 return convolution->convolution_dimension_numbers();
3442 }
3443 if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
3444 return custom_call->convolution_dimension_numbers();
3445 }
3446 LOG(FATAL) << "Unimplemented method.";
3447 }
3448
3449 void HloInstruction::set_convolution_dimension_numbers(
3450 const ConvolutionDimensionNumbers& dnums) {
3451 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
3452 convolution->set_convolution_dimension_numbers(dnums);
3453 } else if (auto custom_call = DynCast<HloCustomCallInstruction>(this)) {
3454 custom_call->set_convolution_dimension_numbers(dnums);
3455 } else {
3456 LOG(FATAL) << "Unimplemented method.";
3457 }
3458 }
3459
3460 int64 HloInstruction::feature_group_count() const {
3461 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
3462 return convolution->feature_group_count();
3463 }
3464 return Cast<HloCustomCallInstruction>(this)->feature_group_count();
3465 }
3466
3467 void HloInstruction::set_feature_group_count(int64 feature_group_count) {
3468 Cast<HloCustomCallInstruction>(this)->set_feature_group_count(
3469 feature_group_count);
3470 }
3471
3472 int64 HloInstruction::batch_group_count() const {
3473 if (auto convolution = DynCast<HloConvolutionInstruction>(this)) {
3474 return convolution->batch_group_count();
3475 }
3476 return Cast<HloCustomCallInstruction>(this)->batch_group_count();
3477 }
3478
3479 void HloInstruction::set_batch_group_count(int64 batch_group_count) {
3480 Cast<HloCustomCallInstruction>(this)->set_batch_group_count(
3481 batch_group_count);
3482 }
3483
3484 HloComputation* HloInstruction::select() const {
3485 return Cast<HloSelectAndScatterInstruction>(this)->select();
3486 }
3487
3488 HloComputation* HloInstruction::scatter() const {
3489 return Cast<HloSelectAndScatterInstruction>(this)->scatter();
3490 }
3491
3492 void HloInstruction::set_select(HloComputation* computation) {
3493 return Cast<HloSelectAndScatterInstruction>(this)->set_select(computation);
3494 }
3495
3496 void HloInstruction::set_scatter(HloComputation* computation) {
3497 return Cast<HloSelectAndScatterInstruction>(this)->set_scatter(computation);
3498 }
3499
3500 const string& HloInstruction::custom_call_target() const {
3501 return Cast<HloCustomCallInstruction>(this)->custom_call_target();
3502 }
3503
3504 const PaddingConfig& HloInstruction::padding_config() const {
3505 return Cast<HloPadInstruction>(this)->padding_config();
3506 }
3507
3508 int64 HloInstruction::slice_sizes(int64 dimension) const {
3509 return Cast<HloDynamicSliceInstruction>(this)->slice_sizes(dimension);
3510 }
3511
3512 const std::vector<int64>& HloInstruction::dynamic_slice_sizes() const {
3513 return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes();
3514 }
3515
3516 const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const {
3517 return Cast<HloGatherInstruction>(this)->gather_dimension_numbers();
3518 }
3519
3520 absl::Span<const int64> HloInstruction::gather_slice_sizes() const {
3521 return Cast<HloGatherInstruction>(this)->gather_slice_sizes();
3522 }
3523
3524 const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
3525 const {
3526 return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
3527 }
3528
3529 const DotDimensionNumbers& HloInstruction::dot_dimension_numbers() const {
3530 return Cast<HloDotInstruction>(this)->dot_dimension_numbers();
3531 }
3532
3533 const DomainMetadata& HloInstruction::operand_side_metadata() const {
3534 return Cast<HloDomainInstruction>(this)->operand_side_metadata();
3535 }
3536
3537 const DomainMetadata& HloInstruction::user_side_metadata() const {
3538 return Cast<HloDomainInstruction>(this)->user_side_metadata();
3539 }
3540
3541 ComparisonDirection HloInstruction::comparison_direction() const {
3542 return Cast<HloCompareInstruction>(this)->direction();
3543 }
3544
3545 const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
3546 return Cast<HloTriangularSolveInstruction>(this)->triangular_solve_options();
3547 }
3548
3549 const CholeskyOptions& HloInstruction::cholesky_options() const {
3550 return Cast<HloCholeskyInstruction>(this)->cholesky_options();
3551 }
3552
3553 } // namespace xla
3554