1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
17
18 #include <set>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/str_join.h"
22 #include "tensorflow/compiler/xla/comparison_util.h"
23 #include "tensorflow/compiler/xla/permutation_util.h"
24 #include "tensorflow/compiler/xla/primitive_util.h"
25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
31 #include "tensorflow/compiler/xla/status_macros.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/errors.h"
35
36 namespace xla {
37
IsCallerInstruction(HloInstruction * hlo)38 bool IsCallerInstruction(HloInstruction* hlo) {
39 switch (hlo->opcode()) {
40 case HloOpcode::kCall:
41 case HloOpcode::kConditional:
42 case HloOpcode::kWhile:
43 case HloOpcode::kAllReduce:
44 case HloOpcode::kMap:
45 case HloOpcode::kReduce:
46 case HloOpcode::kReduceWindow:
47 case HloOpcode::kScatter:
48 case HloOpcode::kSelectAndScatter:
49 case HloOpcode::kSort:
50 case HloOpcode::kFusion:
51 case HloOpcode::kCustomCall:
52 return true;
53 default:
54 return false;
55 }
56 }
57
58 namespace {
59
CheckOperandCount(const HloInstruction * hlo,int expected)60 Status CheckOperandCount(const HloInstruction* hlo, int expected) {
61 if (hlo->operand_count() != expected) {
62 return InternalError("Expected %d operands for %s instruction: %s",
63 expected, HloOpcodeString(hlo->opcode()),
64 hlo->ToString());
65 }
66 return Status::OK();
67 }
68
CheckParameterCount(const HloInstruction * calling_instruction,const HloComputation * computation,int expected)69 Status CheckParameterCount(const HloInstruction* calling_instruction,
70 const HloComputation* computation, int expected) {
71 if (computation->num_parameters() != expected) {
72 return InternalError(
73 "Expected computation %s called from %s to have %d parameters, has %d",
74 computation->name(), calling_instruction->name(), expected,
75 computation->num_parameters());
76 }
77 return Status::OK();
78 }
79 } // namespace
80
Preprocess(HloInstruction * hlo)81 Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
82 if (!hlo->called_computations().empty() && !IsCallerInstruction(hlo)) {
83 return InternalError(
84 "Called computations specified for non-caller instruction %s",
85 hlo->ToString());
86 }
87 absl::optional<int> arity = HloOpcodeArity(hlo->opcode());
88 if (arity) {
89 TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity));
90 }
91 return Status::OK();
92 }
93
HandleElementwiseUnary(HloInstruction * hlo)94 Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) {
95 return CheckUnaryShape(hlo);
96 }
97
HandleElementwiseBinary(HloInstruction * hlo)98 Status ShapeVerifier::HandleElementwiseBinary(HloInstruction* hlo) {
99 return CheckBinaryShape(hlo);
100 }
101
HandleClamp(HloInstruction * clamp)102 Status ShapeVerifier::HandleClamp(HloInstruction* clamp) {
103 return CheckTernaryShape(clamp);
104 }
105
HandleSelect(HloInstruction * select)106 Status ShapeVerifier::HandleSelect(HloInstruction* select) {
107 return CheckTernaryShape(select);
108 }
109
HandleTupleSelect(HloInstruction * tuple_select)110 Status ShapeVerifier::HandleTupleSelect(HloInstruction* tuple_select) {
111 return CheckTernaryShape(tuple_select);
112 }
113
HandleConcatenate(HloInstruction * concatenate)114 Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) {
115 std::vector<const Shape*> operand_shapes;
116 for (const HloInstruction* operand : concatenate->operands()) {
117 operand_shapes.push_back(&operand->shape());
118 }
119 return CheckShape(concatenate,
120 ShapeInference::InferConcatOpShape(
121 operand_shapes, concatenate->concatenate_dimension()));
122 }
123
HandleConvert(HloInstruction * convert)124 Status ShapeVerifier::HandleConvert(HloInstruction* convert) {
125 return CheckShape(convert, ShapeInference::InferConvertShape(
126 convert->operand(0)->shape(),
127 convert->shape().element_type()));
128 }
129
HandleBitcastConvert(HloInstruction * convert)130 Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) {
131 return CheckShape(convert, ShapeInference::InferBitcastConvertShape(
132 convert->operand(0)->shape(),
133 convert->shape().element_type()));
134 }
135
HandleCopy(HloInstruction * copy)136 Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
137 return CheckUnaryShape(copy);
138 }
139
HandleDot(HloInstruction * dot)140 Status ShapeVerifier::HandleDot(HloInstruction* dot) {
141 TF_ASSIGN_OR_RETURN(
142 const Shape expected,
143 ShapeInference::InferDotOpShape(
144 dot->operand(0)->shape(), dot->operand(1)->shape(),
145 dot->dot_dimension_numbers(),
146 /*preferred_element_type=*/dot->shape().element_type()));
147 return CheckShape(dot, expected);
148 }
149
HandleConvolution(HloInstruction * convolution)150 Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
151 TF_ASSIGN_OR_RETURN(
152 Shape expected,
153 ShapeInference::InferConvolveShape(
154 convolution->operand(0)->shape(), convolution->operand(1)->shape(),
155 convolution->feature_group_count(), convolution->batch_group_count(),
156 convolution->window(), convolution->convolution_dimension_numbers(),
157 /*preferred_element_type=*/convolution->shape().element_type()));
158 return CheckShape(convolution, expected);
159 }
160
HandleFft(HloInstruction * fft)161 Status ShapeVerifier::HandleFft(HloInstruction* fft) {
162 TF_ASSIGN_OR_RETURN(
163 const Shape expected,
164 ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(),
165 fft->fft_length()));
166 return CheckShape(fft, expected);
167 }
168
HandleTriangularSolve(HloInstruction * hlo)169 Status ShapeVerifier::HandleTriangularSolve(HloInstruction* hlo) {
170 TF_ASSIGN_OR_RETURN(const Shape expected,
171 ShapeInference::InferTriangularSolveShape(
172 hlo->operand(0)->shape(), hlo->operand(1)->shape(),
173 hlo->triangular_solve_options()));
174 return CheckShape(hlo, expected);
175 }
176
HandleCholesky(HloInstruction * hlo)177 Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) {
178 TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1));
179 TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferCholeskyShape(
180 hlo->operand(0)->shape()));
181 return CheckShape(hlo, expected);
182 }
183
184 // Checks that `hlo`'s set of ReplicaGroups:
185 //
186 // - names each replica 0 through n-1 exactly once, and
187 // - does not contain any empty ReplicaGroups.
188 //
189 // Note that although none of the groups may be empty, `hlo` is allowed to have
190 // 0 groups. That just means it has one big group.
191 //
192 // This is just a minimal set of checks; some instructions may have additional
193 // requirements. For example, all-to-all requires that all ReplicaGroups have
194 // the same number of replicas, but that isn't checked here.
CheckReplicaGroups(HloInstruction * hlo,bool use_global_device_ids)195 static Status CheckReplicaGroups(HloInstruction* hlo,
196 bool use_global_device_ids) {
197 std::set<int64> replicas_seen;
198 for (const ReplicaGroup& g : hlo->replica_groups()) {
199 if (g.replica_ids().empty()) {
200 return InternalError("Instruction cannot have an empty replica group: %s",
201 hlo->ToString());
202 }
203 for (int64 i : g.replica_ids()) {
204 if (!replicas_seen.insert(i).second) {
205 return InternalError(
206 "Replica %d is repeated in instruction's replica-groups: %s", i,
207 hlo->ToString());
208 }
209 }
210 }
211 for (int64 i = 0; i < replicas_seen.size(); ++i) {
212 if (!replicas_seen.count(i)) {
213 return InternalError(
214 "Replica %d is not named in instruction's replica-groups: %s", i,
215 hlo->ToString());
216 }
217 }
218
219 // If use_global_device_ids() is set, replica_groups cannot be empty.
220 // When the channel_id() or use_global_device_ids() is set, device ids in
221 // ReplicaGroup config no longer only mean replica ids. So we skip the check
222 // on the replica count.
223 if (use_global_device_ids) {
224 if (hlo->replica_groups().empty()) {
225 return InternalError(
226 "Replica group must be specified when use_global_device_ids is true");
227 }
228 // No need to check replica_count.
229 return Status::OK();
230 }
231
232 if (auto channel_instr = DynCast<HloChannelInstruction>(hlo)) {
233 if (channel_instr->channel_id()) {
234 return Status::OK();
235 }
236 }
237
238 int64 replica_count = hlo->GetModule()->config().replica_count();
239 if (replica_count != 1 && !replicas_seen.empty() &&
240 replicas_seen.size() != replica_count) {
241 return InternalError(
242 "Replica count in HloModuleConfig is %d, but ReplicaGroup config "
243 "contains %d replicas: %s",
244 replica_count, replicas_seen.size(), hlo->ToString());
245 }
246
247 return Status::OK();
248 }
249
HandleAllGather(HloInstruction * hlo)250 Status ShapeVerifier::HandleAllGather(HloInstruction* hlo) {
251 auto ag = Cast<HloAllGatherInstruction>(hlo);
252 TF_RETURN_IF_ERROR(CheckReplicaGroups(ag, ag->use_global_device_ids()));
253 TF_RET_CHECK(ag->all_gather_dimension() >= 0);
254 TF_RET_CHECK(ag->all_gather_dimension() < ag->shape().rank());
255 TF_RET_CHECK(ag->all_gather_dimension() < ag->operand(0)->shape().rank());
256
257 int64 shard_count = CeilOfRatio(
258 ag->shape().dimensions(ag->all_gather_dimension()),
259 ag->operand(0)->shape().dimensions(ag->all_gather_dimension()));
260 if (ag->channel_id().has_value()) {
261 if (ag->use_global_device_ids()) {
262 TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size());
263 } else {
264 if (ag->replica_groups().empty() ||
265 ag->replica_groups()[0].replica_ids_size() != 1) {
266 return InternalError(
267 "Replica group size must be 1 when use_global_device_ids is "
268 "false if the all-gather is also cross-partition");
269 }
270 }
271 } else if (!ag->replica_groups().empty()) {
272 // Cross-replica all-gather: shard count is subgroup size.
273 TF_RET_CHECK(shard_count == ag->replica_groups()[0].replica_ids_size());
274 }
275 return CheckShape(ag, ShapeInference::InferAllGatherShape(
276 ag->operand(0)->shape(), ag->all_gather_dimension(),
277 shard_count));
278 }
279
HandleAllReduce(HloInstruction * hlo)280 Status ShapeVerifier::HandleAllReduce(HloInstruction* hlo) {
281 auto ar = Cast<HloAllReduceInstruction>(hlo);
282 TF_RETURN_IF_ERROR(CheckReplicaGroups(ar, ar->use_global_device_ids()));
283
284 std::vector<const Shape*> operand_shapes;
285 for (const HloInstruction* operand : hlo->operands()) {
286 operand_shapes.push_back(&operand->shape());
287 }
288 return CheckShape(hlo, ShapeInference::InferAllReduceShape(operand_shapes));
289 }
290
HandleAllToAll(HloInstruction * hlo)291 Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
292 auto* all_to_all = Cast<HloAllToAllInstruction>(hlo);
293 TF_RETURN_IF_ERROR(CheckReplicaGroups(hlo, /*use_global_device_ids=*/false));
294
295 TF_RET_CHECK(all_to_all != nullptr);
296 if (all_to_all->split_dimension()) {
297 if (hlo->replica_groups().empty()) {
298 return InternalError(
299 "An array all-to-all must have an explicit replica_groups config");
300 }
301 }
302
303 // The size of each replica group must be the same (the split count of the
304 // operaion). In case the default replica group is used (empty replica group,
305 // must not be an array all-to-all, as checked above), infer from the number
306 // of operands.
307 const int64 split_count = hlo->replica_groups().empty()
308 ? hlo->operand_count()
309 : hlo->replica_groups()[0].replica_ids_size();
310 for (const ReplicaGroup& g : hlo->replica_groups()) {
311 if (g.replica_ids_size() != split_count) {
312 return InternalError(
313 "Replica group has size %d, but all replica groups in an all-to-all "
314 "must have size N: %s",
315 g.replica_ids_size(), hlo->ToString());
316 }
317 }
318
319 if (all_to_all->split_dimension()) {
320 TF_RET_CHECK(hlo->operand_count() == 1);
321 return CheckShape(
322 hlo, ShapeInference::InferAllToAllShape(
323 hlo->operand(0)->shape(), *all_to_all->split_dimension(),
324 *all_to_all->split_dimension(), split_count));
325 } else {
326 std::vector<const Shape*> operand_shapes;
327 for (const HloInstruction* operand : hlo->operands()) {
328 operand_shapes.push_back(&operand->shape());
329 }
330 return CheckShape(hlo,
331 ShapeInference::InferAllToAllTupleShape(operand_shapes));
332 }
333 }
334
HandlePartitionId(HloInstruction * hlo)335 Status ShapeVerifier::HandlePartitionId(HloInstruction* hlo) {
336 return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
337 }
338
HandleReplicaId(HloInstruction * hlo)339 Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) {
340 return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
341 }
342
343 namespace {
344
CheckDuplicatedSourceOrTarget(HloInstruction * hlo)345 Status CheckDuplicatedSourceOrTarget(HloInstruction* hlo) {
346 // A source or target cannot appear twice in the collective-permute's
347 // source-target pairs.
348 absl::flat_hash_set<int64> seen_sources;
349 absl::flat_hash_set<int64> seen_targets;
350 for (const auto& p : hlo->source_target_pairs()) {
351 if (!seen_sources.insert(p.first).second) {
352 return InternalError(
353 "Source %d appears more than once in instruction's source-target "
354 "pairs: %s",
355 p.first, hlo->ToString());
356 }
357 if (!seen_targets.insert(p.second).second) {
358 return InternalError(
359 "Target %d appears more than once in instruction's source-target "
360 "pairs: %s",
361 p.second, hlo->ToString());
362 }
363 }
364 return Status::OK();
365 }
366
367 } // namespace
368
HandleCollectivePermute(HloInstruction * hlo)369 Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
370 TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo));
371 return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape(
372 hlo->operand(0)->shape()));
373 }
374
HandleCollectivePermuteStart(HloInstruction * hlo)375 Status ShapeVerifier::HandleCollectivePermuteStart(HloInstruction* hlo) {
376 TF_RETURN_IF_ERROR(CheckDuplicatedSourceOrTarget(hlo));
377 return CheckShape(
378 hlo, ShapeUtil::MakeTupleShape(
379 {hlo->operand(0)->shape(), hlo->operand(0)->shape(),
380 ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeShape(U32, {})}));
381 }
382
HandleCollectivePermuteDone(HloInstruction * hlo)383 Status ShapeVerifier::HandleCollectivePermuteDone(HloInstruction* hlo) {
384 return CheckShape(
385 hlo, ShapeUtil::GetTupleElementShape(hlo->operand(0)->shape(), 0));
386 }
387
HandleReducePrecision(HloInstruction * reduce_precision)388 Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
389 return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
390 reduce_precision->operand(0)->shape(),
391 reduce_precision->exponent_bits(),
392 reduce_precision->mantissa_bits()));
393 }
394
CheckIsTokenOperand(const HloInstruction * instruction,int64 operand_no)395 Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction,
396 int64 operand_no) {
397 const HloInstruction* token = instruction->operand(operand_no);
398 if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) {
399 return InternalError(
400 "Expected operand %d to be token-shaped, actual shape is "
401 "%s:\n%s",
402 operand_no, StringifyShape(token->shape()), instruction->ToString());
403 }
404 return Status::OK();
405 }
406
CheckOperandAndParameter(const HloInstruction * instruction,int64 operand_number,const HloComputation * computation,int64 parameter_number)407 Status ShapeVerifier::CheckOperandAndParameter(
408 const HloInstruction* instruction, int64 operand_number,
409 const HloComputation* computation, int64 parameter_number) {
410 const HloInstruction* operand = instruction->operand(operand_number);
411 const HloInstruction* parameter =
412 computation->parameter_instruction(parameter_number);
413 if (!ShapesSame(operand->shape(), parameter->shape())) {
414 return InternalError("Operand %s shape does not match parameter's %s in %s",
415 operand->ToString(), parameter->ToString(),
416 instruction->ToString());
417 }
418 return Status::OK();
419 }
420
HandleInfeed(HloInstruction * instruction)421 Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
422 HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
423 TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));
424
425 // The output of infeed is a tuple containing the data value and a token.
426 return CheckShape(infeed,
427 ShapeUtil::MakeTupleShape(
428 {infeed->infeed_shape(), ShapeUtil::MakeTokenShape()}));
429 }
430
HandleOutfeed(HloInstruction * instruction)431 Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
432 HloOutfeedInstruction* outfeed = Cast<HloOutfeedInstruction>(instruction);
433 TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1));
434
435 // Outfeed has a separate shape field for the value which is outfed to the
436 // host. The shape of the instruction itself is always a token.
437 if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) {
438 return InternalError(
439 "Expected outfeed shape to be equal to operand's shape %s, "
440 "actual shape is %s:\n%s",
441 StringifyShape(outfeed->operand(0)->shape()),
442 StringifyShape(outfeed->outfeed_shape()), outfeed->ToString());
443 }
444 return CheckShape(outfeed, ShapeUtil::MakeTokenShape());
445 }
446
HasCompatibleElementTypes(const Shape & shape_0,const Shape & shape_1,const Shape & result_shape)447 bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0,
448 const Shape& shape_1,
449 const Shape& result_shape) {
450 return ShapeUtil::SameElementType(shape_0, shape_1) &&
451 (ShapeUtil::SameElementType(shape_0, result_shape) ||
452 (allow_mixed_precision_ &&
453 ShapeUtil::SameElementTypeIgnoringFpPrecision(shape_0,
454 result_shape)));
455 }
456
HandleRng(HloInstruction * instruction)457 Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
458 TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2));
459
460 const Shape& shape_0 = instruction->operand(0)->shape();
461 const Shape& shape_1 = instruction->operand(1)->shape();
462 if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) {
463 return InternalError(
464 "Expected scalar types for the two operands of Rng instruction: %s",
465 instruction->ToString());
466 }
467
468 if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) {
469 return InternalError(
470 "Expected compatible element types for the result and the two operands"
471 " of Rng instruction: %s",
472 instruction->ToString());
473 }
474
475 PrimitiveType element_type = shape_0.element_type();
476 switch (instruction->random_distribution()) {
477 case RNG_UNIFORM:
478 if (!primitive_util::IsFloatingPointType(element_type) &&
479 !primitive_util::IsIntegralType(element_type) &&
480 element_type != PRED) {
481 return InternalError(
482 "Element type not supported."
483 " Expected element to be of floating point type, integral type or"
484 " predicate type for RngUniform: %s",
485 instruction->ToString());
486 }
487 break;
488
489 case RNG_NORMAL:
490 if (!primitive_util::IsFloatingPointType(element_type)) {
491 return InternalError(
492 "Element type not supported."
493 " Expected element to be FloatingPointType for RngNormal: %s",
494 instruction->ToString());
495 }
496 break;
497 default:
498 return InternalError(
499 "Invalid Rng distribution %s",
500 RandomDistribution_Name(instruction->random_distribution()));
501 }
502
503 return Status::OK();
504 }
505
HandleRngBitGenerator(HloInstruction * hlo)506 Status ShapeVerifier::HandleRngBitGenerator(HloInstruction* hlo) {
507 if (!hlo->shape().IsTuple() || hlo->shape().tuple_shapes_size() != 2) {
508 return InternalError(
509 "Expected tuple shape with 2 elements for RngBitGenerator. Got: %s",
510 hlo->shape().ToString());
511 }
512 if (!ShapeUtil::Compatible(hlo->operand(0)->shape(),
513 hlo->shape().tuple_shapes(0))) {
514 return InternalError(
515 "Expected state shape to match between input and output for "
516 "RngBitGenerator. Got %s vs. %s",
517 hlo->operand(0)->shape().ToString(),
518 hlo->shape().tuple_shapes(0).ToString());
519 }
520 return Status::OK();
521 }
522
HandleRngGetAndUpdateState(HloInstruction * instruction)523 Status ShapeVerifier::HandleRngGetAndUpdateState(HloInstruction* instruction) {
524 TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0));
525 const Shape& result_shape = instruction->shape();
526 const Shape expected_shape = ShapeUtil::MakeShape(U64, {2});
527 if (!ShapeUtil::Compatible(result_shape, expected_shape)) {
528 return InternalError(
529 "Invalid RngGetAndUpdateState, expect result to have shape %s, got %s ",
530 StringifyShape(expected_shape), StringifyShape(result_shape));
531 }
532
533 return Status::OK();
534 }
535
HandleReverse(HloInstruction * reverse)536 Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
537 return CheckShape(
538 reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(),
539 reverse->dimensions()));
540 }
541
HandleSort(HloInstruction * sort)542 Status ShapeVerifier::HandleSort(HloInstruction* sort) {
543 if (sort->operand_count() < 1) {
544 return InternalError("Expected at least 1 operand for %s instruction: %s",
545 HloOpcodeString(sort->opcode()), sort->ToString());
546 }
547 HloComputation* compare = sort->to_apply();
548
549 // Check that the 'compare' computation returns a PRED.
550 Shape compare_shape = compare->root_instruction()->shape();
551 if (!ShapeUtil::Compatible(compare_shape, ShapeUtil::MakeShape(PRED, {}))) {
552 return InternalError(
553 "The Sort compare computation shape does not lead to a scalar "
554 "predicate shape: %s",
555 StringifyShape(compare_shape));
556 }
557
558 // Check that the number of parameters of the 'compare' computation is
559 // correct.
560 TF_RETURN_IF_ERROR(
561 CheckParameterCount(sort, compare, sort->operand_count() * 2));
562
563 // Verify that the operands of the compare computation have the correct scalar
564 // shapes.
565 for (int64 parameter_idx = 0; parameter_idx < compare->num_parameters();
566 ++parameter_idx) {
567 int64 operand_idx = parameter_idx / 2;
568 Shape expected_scalar_shape = ShapeUtil::MakeShape(
569 sort->operand(operand_idx)->shape().element_type(), {});
570 Shape actual_parameter_shape =
571 compare->parameter_instruction(parameter_idx)->shape();
572 if (!ShapeUtil::CompatibleIgnoringFpPrecision(expected_scalar_shape,
573 actual_parameter_shape)) {
574 return InternalError(
575 "Expected the %lld-th parameter of the compare computation of sort "
576 "to have shape %s, but got %s",
577 parameter_idx, StringifyShape(expected_scalar_shape),
578 StringifyShape(actual_parameter_shape));
579 }
580 }
581
582 // Verify that all operand shapes have the same dimensions.
583 for (int64 operand = 1; operand < sort->operand_count(); ++operand) {
584 if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(),
585 sort->operand(operand)->shape())) {
586 return InternalError(
587 "Expected sort to have to have the same dimensions for all operands. "
588 "First operand shape is: %s\n, shape (operand index %lld) is: %s",
589 StringifyShape(sort->operand(0)->shape()), operand,
590 StringifyShape(sort->operand(operand)->shape()));
591 }
592 }
593 return CheckVariadicShape(sort);
594 }
595
HandleConstant(HloInstruction * constant)596 Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
597 if (!Cast<HloConstantInstruction>(constant)->HasLiteral()) {
598 return InternalError("Constant is required to have a valid literal: %s",
599 constant->ToString());
600 }
601 return CheckShape(constant, constant->literal().shape(),
602 /*only_compare_minor_to_major_in_layout=*/true);
603 }
604
HandleIota(HloInstruction * hlo)605 Status ShapeVerifier::HandleIota(HloInstruction* hlo) {
606 auto* iota = Cast<HloIotaInstruction>(hlo);
607 if (!iota->shape().IsArray()) {
608 return InternalError("Iota does not support non-array result.");
609 }
610 const int64 rank = iota->shape().rank();
611 if (rank == 0) {
612 return InternalError("Iota does not support scalars.");
613 }
614 int64 iota_dimension = iota->iota_dimension();
615 if (iota_dimension >= rank || iota_dimension < 0) {
616 return InternalError(
617 "The iota dimension cannot go beyond the operation rank or be "
618 "negative.");
619 }
620
621 PrimitiveType primitive_type = iota->shape().element_type();
622 if (!primitive_util::IsIntegralType(primitive_type) &&
623 !primitive_util::IsFloatingPointType(primitive_type) &&
624 !primitive_util::IsComplexType(primitive_type)) {
625 return InvalidArgument(
626 "Only support iota of integral, floating point or complex primitive "
627 "types, got %s",
628 PrimitiveType_Name(primitive_type));
629 }
630
631 return Status::OK();
632 }
633
HandleGetTupleElement(HloInstruction * get_tuple_element)634 Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
635 return CheckShape(get_tuple_element,
636 ShapeInference::InferGetTupleElementShape(
637 get_tuple_element->operand(0)->shape(),
638 get_tuple_element->tuple_index()));
639 }
640
641 namespace {
SameElementTypesForOperandsAndToApplyParameters(const HloInstruction & instruction,int64 num_operands_to_check)642 Status SameElementTypesForOperandsAndToApplyParameters(
643 const HloInstruction& instruction, int64 num_operands_to_check) {
644 const ProgramShape& to_apply = instruction.to_apply()->ComputeProgramShape();
645 for (int i = 0; i < num_operands_to_check; ++i) {
646 const Shape& parameter_shape = to_apply.parameters(i);
647 const Shape& operand_shape = instruction.operands()[i]->shape();
648 if (!ShapeUtil::SameElementType(parameter_shape, operand_shape)) {
649 return InvalidArgument(
650 "Shape mismatch between to_apply computation"
651 " parameter and operand %d in %s.",
652 i, instruction.ToString().c_str());
653 }
654 }
655 return Status::OK();
656 }
657 } // namespace
658
HandleReduce(HloInstruction * reduce)659 Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
660 if (reduce->operand_count() % 2 != 0) {
661 return InternalError(
662 "Expected an even number of operands for %s instruction: %s",
663 HloOpcodeString(reduce->opcode()), reduce->ToString());
664 }
665
666 std::vector<const Shape*> operand_shapes;
667 for (const HloInstruction* operand : reduce->operands()) {
668 operand_shapes.push_back(&operand->shape());
669 }
670 TF_RETURN_IF_ERROR(
671 CheckShape(reduce, ShapeInference::InferReduceShape(
672 operand_shapes, reduce->dimensions(),
673 reduce->to_apply()->ComputeProgramShape())));
674
675 return allow_mixed_precision_
676 ? Status::OK()
677 : SameElementTypesForOperandsAndToApplyParameters(
678 *reduce, reduce->operands().size() - 1);
679 }
680
HandleBitcast(HloInstruction * bitcast)681 Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
682 if (layout_sensitive_ &&
683 shape_size_function_(bitcast->shape()) !=
684 shape_size_function_(bitcast->operand(0)->shape())) {
685 return InternalError(
686 "Bitcast cannot have different shape sizes of output (%d) and operand "
687 "(%d) (%s) (%s)",
688 shape_size_function_(bitcast->shape()),
689 shape_size_function_(bitcast->operand(0)->shape()),
690 bitcast->shape().ToString(true),
691 bitcast->operand(0)->shape().ToString(true));
692 }
693 return Status::OK();
694 }
695
HandleBroadcast(HloInstruction * broadcast)696 Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
697 // HLO broadcast has no exact analog at the proto level so there is no
698 // ShapeInference method. Check the output shape explicitly.
699 const Shape& operand_shape = broadcast->operand(0)->shape();
700 // Check for mixed precision.
701 TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape));
702 TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size());
703 for (int64 operand_dimension = 0; operand_dimension < operand_shape.rank();
704 ++operand_dimension) {
705 int64 output_dimension = broadcast->dimensions()[operand_dimension];
706 TF_RET_CHECK((output_dimension < broadcast->shape().rank()) &&
707 output_dimension >= 0 &&
708 (broadcast->shape().dimensions(output_dimension) ==
709 operand_shape.dimensions(operand_dimension)))
710 << broadcast->ToString() << " operand shape " << operand_shape;
711 }
712 return Status::OK();
713 }
714
HandleDynamicReshape(HloInstruction * dynamic_reshape)715 Status ShapeVerifier::HandleDynamicReshape(HloInstruction* dynamic_reshape) {
716 // Check for mixed precision.
717 const Shape& operand_shape = dynamic_reshape->operand(0)->shape();
718 TF_RET_CHECK(SameElementType(dynamic_reshape->shape(), operand_shape));
719 TF_RET_CHECK(ShapeUtil::ElementsIn(dynamic_reshape->shape()) ==
720 ShapeUtil::ElementsIn(operand_shape));
721 TF_RET_CHECK(dynamic_reshape->shape().rank() + 1 ==
722 dynamic_reshape->operand_count());
723 for (int64 i = 1; i < dynamic_reshape->operand_count(); ++i) {
724 TF_RET_CHECK(dynamic_reshape->operand(i)->shape().element_type() == S32);
725 }
726 return Status::OK();
727 }
728
HandleReshape(HloInstruction * reshape)729 Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
730 // Check for mixed precision.
731 const Shape& operand_shape = reshape->operand(0)->shape();
732 TF_RET_CHECK(SameElementType(reshape->shape(), operand_shape));
733 TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
734 ShapeUtil::ElementsIn(operand_shape));
735 return Status::OK();
736 }
737
HandleTranspose(HloInstruction * transpose)738 Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
739 return CheckShape(
740 transpose, ShapeInference::InferTransposeShape(
741 transpose->operand(0)->shape(), transpose->dimensions()));
742 }
743
HandleParameter(HloInstruction * hlo)744 Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
745 return Status::OK();
746 }
747
HandleFusion(HloInstruction * fusion)748 Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
749 if (fusion->called_computations().size() != 1) {
750 return InternalError(
751 "Fusion has a non-unary number of called computations (%s)",
752 fusion->ToString().c_str());
753 }
754 const Shape& root_computation_shape =
755 fusion->called_computations()[0]->root_instruction()->shape();
756 if (!ShapesSame(fusion->shape(), root_computation_shape)) {
757 return InternalError(
758 "Fused computation shape (%s) is not equal to the fusion shape (%s)",
759 root_computation_shape.ToString(true), fusion->shape().ToString(true));
760 }
761
762 auto& fused_parameters = fusion->fused_parameters();
763 if (fused_parameters.size() != fusion->operand_count()) {
764 return InternalError(
765 "Fused parameter count (%d) does not match the number of operands (%d)"
766 " passed to the fusion instruction in: %s.",
767 fused_parameters.size(), fusion->operand_count(),
768 fusion->ToString().c_str());
769 }
770 for (HloInstruction* fused_param : fused_parameters) {
771 int64 param_no = fused_param->parameter_number();
772 if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) {
773 return InternalError(
774 "Shape mismatch between parameter number %d and its operand in "
775 "%s.",
776 param_no, fusion->ToString().c_str());
777 }
778 }
779 return Status::OK();
780 }
781
HandleCall(HloInstruction * call)782 Status ShapeVerifier::HandleCall(HloInstruction* call) {
783 TF_RETURN_IF_ERROR(
784 CheckParameterCount(call, call->to_apply(), call->operand_count()));
785 for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
786 TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
787 }
788 // The shape of kCall should match the shape of the computation it calls.
789 return CheckShape(call, call->to_apply()->root_instruction()->shape());
790 }
791
HandleCustomCall(HloInstruction * instruction)792 Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) {
793 const HloCustomCallInstruction* custom_call =
794 DynCast<const HloCustomCallInstruction>(instruction);
795 TF_RET_CHECK(custom_call != nullptr);
796 if (custom_call->layout_constrained()) {
797 // If the layout is constrained, verify all the respective shapes have
798 // layouts and that the constrained operand shapes match the shapes of the
799 // operands.
800 TF_RET_CHECK(LayoutUtil::HasLayout(custom_call->shape()));
801 TF_RET_CHECK(custom_call->operand_count() ==
802 custom_call->operand_shapes_with_layout().size());
803 for (int64 i = 0; i < custom_call->operand_count(); ++i) {
804 const Shape& operand_shape_with_layout =
805 custom_call->operand_shapes_with_layout()[i];
806 TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(),
807 operand_shape_with_layout))
808 << custom_call->operand(i)->shape().ToString() << " operand "
809 << operand_shape_with_layout.ToString();
810 TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
811 }
812 }
813 for (const auto& pair : custom_call->output_to_operand_aliasing()) {
814 TF_RET_CHECK(pair.second.first < custom_call->operand_count())
815 << "Invalid aliasing operand index.";
816 TF_RET_CHECK(ShapeUtil::IndexIsValid(
817 custom_call->operand(pair.second.first)->shape(), pair.second.second))
818 << "Invalid aliasing operand shape index.";
819 TF_RET_CHECK(ShapeUtil::IndexIsValid(custom_call->shape(), pair.first))
820 << "Invalid aliasing output shape index.";
821 const Shape& output_subshape =
822 ShapeUtil::GetSubshape(custom_call->shape(), pair.first);
823 const Shape& operand_subshape = ShapeUtil::GetSubshape(
824 custom_call->operand(pair.second.first)->shape(), pair.second.second);
825 if (layout_sensitive_) {
826 TF_RET_CHECK(operand_subshape == output_subshape)
827 << "Different aliasing shapes: " << operand_subshape.ToString()
828 << " vs " << output_subshape.ToString();
829 } else {
830 TF_RET_CHECK(ShapeUtil::Compatible(output_subshape, operand_subshape))
831 << "Different aliasing shapes: " << operand_subshape.ToString()
832 << " vs " << output_subshape.ToString();
833 }
834 }
835 return Status::OK();
836 }
837
HandleSlice(HloInstruction * slice)838 Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
839 return CheckShape(slice,
840 ShapeInference::InferSliceShape(
841 slice->operand(0)->shape(), slice->slice_starts(),
842 slice->slice_limits(), slice->slice_strides()));
843 }
844
HandleDynamicSlice(HloInstruction * dynamic_slice)845 Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) {
846 return CheckShape(
847 dynamic_slice,
848 ShapeInference::InferDynamicSliceShape(
849 dynamic_slice->operand(0)->shape(),
850 Cast<HloDynamicSliceInstruction>(dynamic_slice)->index_shapes(),
851 dynamic_slice->dynamic_slice_sizes()));
852 }
853
HandleDynamicUpdateSlice(HloInstruction * dynamic_update_slice)854 Status ShapeVerifier::HandleDynamicUpdateSlice(
855 HloInstruction* dynamic_update_slice) {
856 return CheckShape(
857 dynamic_update_slice,
858 ShapeInference::InferDynamicUpdateSliceShape(
859 dynamic_update_slice->operand(0)->shape(),
860 dynamic_update_slice->operand(1)->shape(),
861 Cast<HloDynamicUpdateSliceInstruction>(dynamic_update_slice)
862 ->index_shapes()));
863 }
864
HandleTuple(HloInstruction * tuple)865 Status ShapeVerifier::HandleTuple(HloInstruction* tuple) {
866 return CheckVariadicShape(tuple);
867 }
868
HandleMap(HloInstruction * map)869 Status ShapeVerifier::HandleMap(HloInstruction* map) {
870 std::vector<const Shape*> operand_shapes;
871 int64 max_operand_rank = 0;
872 for (const HloInstruction* operand : map->operands()) {
873 operand_shapes.push_back(&operand->shape());
874 max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
875 }
876 // TODO(b/65689298) Remove code below once Map is generalized to accept
877 // arbitrary map dimensions.
878 std::vector<int64> map_dims(max_operand_rank);
879 std::iota(map_dims.begin(), map_dims.end(), 0);
880
881 TF_RETURN_IF_ERROR(CheckShape(
882 map,
883 ShapeInference::InferMapShape(
884 operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims)));
885
886 return allow_mixed_precision_
887 ? Status::OK()
888 : SameElementTypesForOperandsAndToApplyParameters(
889 *map, map->operands().size());
890 }
891
HandleReduceWindow(HloInstruction * reduce_window)892 Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) {
893 VLOG(2) << "Verify reduce window:" << reduce_window->ToString() << "\n";
894 auto reduce_window_instr = Cast<HloReduceWindowInstruction>(reduce_window);
895 auto input_shapes = reduce_window_instr->input_array_shapes();
896 VLOG(2) << "reduce window input shape count: " << input_shapes.size() << "\n";
897 auto init_shapes = reduce_window_instr->init_value_shapes();
898 VLOG(2) << "reduce instruction is :" << reduce_window->ToString() << "\n";
899 TF_RETURN_IF_ERROR(CheckShape(
900 reduce_window, ShapeInference::InferReduceWindowShape(
901 input_shapes, init_shapes, reduce_window->window(),
902 reduce_window->to_apply()->ComputeProgramShape())));
903
904 return allow_mixed_precision_
905 ? Status::OK()
906 : SameElementTypesForOperandsAndToApplyParameters(*reduce_window,
907 1);
908 }
909
HandleSelectAndScatter(HloInstruction * instruction)910 Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
911 return CheckShape(
912 instruction,
913 ShapeInference::InferSelectAndScatterShape(
914 instruction->operand(0)->shape(),
915 instruction->select()->ComputeProgramShape(), instruction->window(),
916 instruction->operand(1)->shape(), instruction->operand(2)->shape(),
917 instruction->scatter()->ComputeProgramShape()));
918 }
919
HandleWhile(HloInstruction * xla_while)920 Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
921 TF_RETURN_IF_ERROR(
922 CheckParameterCount(xla_while, xla_while->while_body(), 1));
923 TF_RETURN_IF_ERROR(
924 CheckParameterCount(xla_while, xla_while->while_condition(), 1));
925 TF_RETURN_IF_ERROR(
926 CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
927 TF_RETURN_IF_ERROR(
928 CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
929 const Shape& conditional_shape =
930 xla_while->while_condition()->root_instruction()->shape();
931 if (!ShapeUtil::Compatible(conditional_shape,
932 ShapeUtil::MakeShape(PRED, {}))) {
933 return InternalError(
934 "Conditional computation shape does not lead to a scalar predicate "
935 "shape: %s",
936 StringifyShape(conditional_shape));
937 }
938 // The shape of kWhile should match the shape of the body computation it
939 // calls.
940 return CheckShape(xla_while,
941 xla_while->while_body()->root_instruction()->shape());
942 }
943
HandleConditional(HloInstruction * conditional)944 Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
945 if (!ShapeUtil::IsScalar(conditional->operand(0)->shape())) {
946 return InvalidArgument(
947 "The first operand of conditional must be a scalar. Got %s",
948 conditional->operand(0)->shape().DebugString());
949 }
950 const int num_branches = conditional->branch_count();
951 PrimitiveType operand0_type = conditional->operand(0)->shape().element_type();
952 if (operand0_type == PRED) {
953 TF_RET_CHECK(num_branches == 2);
954 } else {
955 if (operand0_type != S32) {
956 return InvalidArgument(
957 "The first operand of indexed conditional must be a scalar of S32. "
958 "Got"
959 " type %s.",
960 PrimitiveType_Name(operand0_type));
961 }
962 TF_RET_CHECK(num_branches >= 1);
963 }
964 TF_RETURN_IF_ERROR(CheckOperandCount(conditional, num_branches + 1));
965 for (int j = 0; j < num_branches; ++j) {
966 TF_RETURN_IF_ERROR(CheckParameterCount(
967 conditional, conditional->branch_computation(j), 1));
968 TF_RETURN_IF_ERROR(CheckOperandAndParameter(
969 conditional, j + 1, conditional->branch_computation(j), 0));
970 TF_RETURN_IF_ERROR(CheckShape(
971 conditional,
972 conditional->branch_computation(j)->root_instruction()->shape()));
973 }
974 return Status::OK();
975 }
976
HandlePad(HloInstruction * pad)977 Status ShapeVerifier::HandlePad(HloInstruction* pad) {
978 return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(),
979 pad->operand(1)->shape(),
980 pad->padding_config()));
981 }
982
HandleCopyStart(HloInstruction * copy_start)983 Status ShapeVerifier::HandleCopyStart(HloInstruction* copy_start) {
984 return CheckShape(copy_start,
985 ShapeUtil::MakeTupleShape({copy_start->operand(0)->shape(),
986 copy_start->operand(0)->shape(),
987 ShapeUtil::MakeShape(U32, {})}),
988 /*only_compare_minor_to_major_in_layout=*/true);
989 }
990
HandleCopyDone(HloInstruction * copy_done)991 Status ShapeVerifier::HandleCopyDone(HloInstruction* copy_done) {
992 const Shape& operand_shape = copy_done->operand(0)->shape();
993 const Shape& dest_shape = ShapeUtil::GetTupleElementShape(operand_shape, 0);
994 const Shape& src_shape = ShapeUtil::GetTupleElementShape(operand_shape, 1);
995 if (!ShapesSame(dest_shape, src_shape,
996 /*minor_to_major_only=*/false,
997 /*ignore_memory_space=*/true)) {
998 return InternalError(
999 "Source and destination buffers in CopyDone arguments need to be the "
1000 "same shape found %s and %s\n%s",
1001 StringifyShape(dest_shape), StringifyShape(src_shape),
1002 copy_done->ToString());
1003 }
1004 return CheckShape(copy_done, ShapeUtil::GetTupleElementShape(
1005 copy_done->operand(0)->shape(), 0));
1006 }
1007
HandleSend(HloInstruction * send)1008 Status ShapeVerifier::HandleSend(HloInstruction* send) {
1009 return CheckShape(send,
1010 ShapeUtil::MakeTupleShape({send->operand(0)->shape(),
1011 ShapeUtil::MakeShape(U32, {}),
1012 ShapeUtil::MakeTokenShape()}),
1013 /*only_compare_minor_to_major_in_layout=*/true);
1014 }
1015
HandleSendDone(HloInstruction * send_done)1016 Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
1017 return CheckShape(send_done, ShapeUtil::MakeTokenShape());
1018 }
1019
HandleRecv(HloInstruction * recv)1020 Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
1021 return CheckShape(
1022 recv,
1023 ShapeUtil::MakeTupleShape(
1024 {ShapeUtil::GetTupleElementShape(recv->shape(), 0),
1025 ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}),
1026 /*only_compare_minor_to_major_in_layout=*/true);
1027 }
1028
HandleRecvDone(HloInstruction * recv_done)1029 Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
1030 return CheckShape(
1031 recv_done,
1032 ShapeUtil::MakeTupleShape(
1033 {ShapeUtil::GetTupleElementShape(recv_done->operand(0)->shape(), 0),
1034 ShapeUtil::MakeTokenShape()}));
1035 }
1036
HandleBatchNormTraining(HloInstruction * batch_norm_training)1037 Status ShapeVerifier::HandleBatchNormTraining(
1038 HloInstruction* batch_norm_training) {
1039 return CheckShape(batch_norm_training,
1040 ShapeInference::InferBatchNormTrainingShape(
1041 batch_norm_training->operand(0)->shape(),
1042 batch_norm_training->operand(1)->shape(),
1043 batch_norm_training->operand(2)->shape(),
1044 batch_norm_training->feature_index()));
1045 }
1046
HandleBatchNormInference(HloInstruction * batch_norm_inference)1047 Status ShapeVerifier::HandleBatchNormInference(
1048 HloInstruction* batch_norm_inference) {
1049 return CheckShape(batch_norm_inference,
1050 ShapeInference::InferBatchNormInferenceShape(
1051 batch_norm_inference->operand(0)->shape(),
1052 batch_norm_inference->operand(1)->shape(),
1053 batch_norm_inference->operand(2)->shape(),
1054 batch_norm_inference->operand(3)->shape(),
1055 batch_norm_inference->operand(4)->shape(),
1056 batch_norm_inference->feature_index()));
1057 }
1058
HandleBatchNormGrad(HloInstruction * batch_norm_grad)1059 Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
1060 return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape(
1061 batch_norm_grad->operand(0)->shape(),
1062 batch_norm_grad->operand(1)->shape(),
1063 batch_norm_grad->operand(2)->shape(),
1064 batch_norm_grad->operand(3)->shape(),
1065 batch_norm_grad->operand(4)->shape(),
1066 batch_norm_grad->feature_index()));
1067 }
1068
1069 namespace {
1070
1071 // Checks that the instruction does not have mixed precision floating point
1072 // inputs.
CheckMixedPrecisionOperands(const HloInstruction * instruction)1073 Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
1074 switch (instruction->opcode()) {
1075 // Allow-list the following opcodes for mixed-precision check, because
1076 // they involve data pass through or grouping via tuples, where the
1077 // precisions of buffers can be different.
1078 case HloOpcode::kCall:
1079 case HloOpcode::kConditional:
1080 case HloOpcode::kConstant:
1081 case HloOpcode::kConvolution:
1082 case HloOpcode::kDot:
1083 case HloOpcode::kAllReduce:
1084 case HloOpcode::kCopyDone:
1085 case HloOpcode::kCopyStart:
1086 case HloOpcode::kCustomCall:
1087 case HloOpcode::kDomain:
1088 case HloOpcode::kFusion:
1089 case HloOpcode::kGetTupleElement:
1090 case HloOpcode::kInfeed:
1091 case HloOpcode::kOutfeed:
1092 case HloOpcode::kParameter:
1093 case HloOpcode::kRecv:
1094 case HloOpcode::kRecvDone:
1095 case HloOpcode::kReducePrecision:
1096 case HloOpcode::kReduceWindow:
1097 case HloOpcode::kTupleSelect:
1098 case HloOpcode::kSend:
1099 case HloOpcode::kSendDone:
1100 case HloOpcode::kSort:
1101 case HloOpcode::kTuple:
1102 case HloOpcode::kWhile:
1103 break;
1104 default: {
1105 PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID;
1106 for (auto operand : instruction->operands()) {
1107 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
1108 operand->shape(),
1109 [&](const Shape& subshape, const ShapeIndex& index) {
1110 if (!ShapeUtil::ElementIsFloating(subshape)) {
1111 return Status::OK();
1112 }
1113 if (fp_type == PRIMITIVE_TYPE_INVALID) {
1114 fp_type = subshape.element_type();
1115 } else if (fp_type != subshape.element_type()) {
1116 return InternalError(
1117 "Seen floating point types of different precisions in "
1118 "%s, but mixed precision is disallowed.",
1119 instruction->ToString());
1120 }
1121 return Status::OK();
1122 }));
1123 }
1124 }
1125 }
1126 return Status::OK();
1127 }
1128
1129 } // namespace
1130
HandleGather(HloInstruction * gather)1131 Status ShapeVerifier::HandleGather(HloInstruction* gather) {
1132 return CheckShape(
1133 gather,
1134 ShapeInference::InferGatherShape(
1135 gather->operand(0)->shape(), gather->operand(1)->shape(),
1136 gather->gather_dimension_numbers(), gather->gather_slice_sizes()));
1137 }
1138
HandleScatter(HloInstruction * scatter)1139 Status ShapeVerifier::HandleScatter(HloInstruction* scatter) {
1140 return CheckShape(
1141 scatter, ShapeInference::InferScatterShape(
1142 scatter->operand(0)->shape(), scatter->operand(1)->shape(),
1143 scatter->operand(2)->shape(),
1144 scatter->to_apply()->ComputeProgramShape(),
1145 scatter->scatter_dimension_numbers()));
1146 }
1147
HandleAfterAll(HloInstruction * token)1148 Status ShapeVerifier::HandleAfterAll(HloInstruction* token) {
1149 std::vector<const Shape*> operand_shapes;
1150 for (const HloInstruction* operand : token->operands()) {
1151 operand_shapes.push_back(&operand->shape());
1152 }
1153 return CheckShape(token, ShapeUtil::MakeTokenShape());
1154 }
1155
HandleAddDependency(HloInstruction * add_dependency)1156 Status ShapeVerifier::HandleAddDependency(HloInstruction* add_dependency) {
1157 TF_RETURN_IF_ERROR(CheckIsTokenOperand(add_dependency, 1));
1158 return CheckShape(add_dependency, add_dependency->operand(0)->shape());
1159 }
1160
HandleGetDimensionSize(HloInstruction * get_size)1161 Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) {
1162 return CheckShape(get_size,
1163 ShapeInference::InferGetDimensionSizeShape(
1164 get_size->operand(0)->shape(), get_size->dimension()));
1165 }
1166
HandleSetDimensionSize(HloInstruction * set_size)1167 Status ShapeVerifier::HandleSetDimensionSize(HloInstruction* set_size) {
1168 return CheckShape(set_size,
1169 ShapeInference::InferSetDimensionSizeShape(
1170 set_size->operand(0)->shape(),
1171 set_size->operand(1)->shape(), set_size->dimension()));
1172 }
1173
CheckShape(const HloInstruction * instruction,const Shape & inferred_shape,bool only_compare_minor_to_major_in_layout)1174 Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
1175 const Shape& inferred_shape,
1176 bool only_compare_minor_to_major_in_layout) {
1177 // If allow_mixed_precision_ is false, check if there are operands with
1178 // different precisions. We need this check because ShapeInference allows
1179 // mixed precision inputs.
1180 if (!allow_mixed_precision_) {
1181 TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction));
1182 }
1183
1184 // Check if the output shape matches the expected shape.
1185 //
1186 // We treat BF16 and F32 as compatible types if mixed precision is allowed,
1187 // but only when the instruction defines the BF16/F32 buffer.
1188 bool equal = [&] {
1189 switch (instruction->opcode()) {
1190 // The opcodes below can't have implicit layout conversions, nor can they
1191 // implicitly transform f32 -> bf16. Fundamentally these are either
1192 // reinterpreting existing data (e.g. kBitcast) or shuffling data around
1193 // without modifying it (e.g. kGetTupleElement, kTupleSelect).
1194 case HloOpcode::kBitcast:
1195 case HloOpcode::kCall:
1196 case HloOpcode::kConditional:
1197 case HloOpcode::kConstant:
1198 case HloOpcode::kCopyDone:
1199 case HloOpcode::kCopyStart:
1200 case HloOpcode::kCustomCall:
1201 case HloOpcode::kDynamicUpdateSlice:
1202 case HloOpcode::kGetTupleElement:
1203 case HloOpcode::kInfeed:
1204 case HloOpcode::kOutfeed:
1205 case HloOpcode::kParameter:
1206 case HloOpcode::kRecv:
1207 case HloOpcode::kRecvDone:
1208 case HloOpcode::kSend:
1209 case HloOpcode::kSendDone:
1210 case HloOpcode::kTuple:
1211 case HloOpcode::kTupleSelect:
1212 case HloOpcode::kWhile:
1213 return ShapesSame(instruction->shape(), inferred_shape,
1214 only_compare_minor_to_major_in_layout);
1215
1216 // We allow arbitrary layout and f32->bf16 transformations on all other
1217 // instructions, although this may be made more strict pending discussion
1218 // in b/112709536.
1219 default:
1220 if (allow_mixed_precision_) {
1221 return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(),
1222 inferred_shape);
1223 } else {
1224 return ShapeUtil::Compatible(instruction->shape(), inferred_shape);
1225 }
1226 }
1227 }();
1228 if (!equal) {
1229 return InternalError(
1230 "Expected instruction to have shape equal to %s, actual "
1231 "shape is %s:\n%s",
1232 StringifyShape(inferred_shape), StringifyShape(instruction->shape()),
1233 instruction->ToString());
1234 }
1235 return Status::OK();
1236 }
1237
CheckShape(const HloInstruction * instruction,const StatusOr<Shape> & inferred_shape_status)1238 Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
1239 const StatusOr<Shape>& inferred_shape_status) {
1240 if (!inferred_shape_status.ok()) {
1241 Status s = inferred_shape_status.status();
1242 tensorflow::errors::AppendToMessage(&s, ", for instruction ",
1243 instruction->ToString());
1244 return s;
1245 }
1246 return CheckShape(instruction, inferred_shape_status.ValueOrDie());
1247 }
1248
CheckUnaryShape(const HloInstruction * instruction)1249 Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) {
1250 return CheckShape(instruction,
1251 ShapeInference::InferUnaryOpShape(instruction->opcode(),
1252 instruction->operand(0)));
1253 }
1254
CheckBinaryShape(const HloInstruction * instruction)1255 Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) {
1256 return CheckShape(
1257 instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(),
1258 instruction->operand(0),
1259 instruction->operand(1)));
1260 }
1261
CheckTernaryShape(const HloInstruction * instruction)1262 Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) {
1263 return CheckShape(instruction,
1264 ShapeInference::InferTernaryOpShape(
1265 instruction->opcode(), instruction->operand(0),
1266 instruction->operand(1), instruction->operand(2)));
1267 }
1268
CheckVariadicShape(const HloInstruction * instruction)1269 Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
1270 return CheckShape(instruction,
1271 ShapeInference::InferVariadicOpShape(
1272 instruction->opcode(), instruction->operands()));
1273 }
1274
VerifyEntryComputationLayout(const HloModule & module)1275 Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) {
1276 const HloComputation* computation = module.entry_computation();
1277 const auto& layout = module.entry_computation_layout();
1278 const ShapeLayout& result_layout = layout.result_layout();
1279
1280 TF_RETURN_IF_ERROR(
1281 ShapeUtil::ValidateShapeWithOptionalLayout(result_layout.shape()));
1282
1283 if (!ShapeUtil::Compatible(computation->root_instruction()->shape(),
1284 result_layout.shape())) {
1285 return InternalError(
1286 "Shape of the root instruction of entry computation (%s) should be "
1287 "compatible to one specified in module's entry computation layout (%s)",
1288 ShapeUtil::HumanString(computation->root_instruction()->shape()),
1289 ShapeUtil::HumanString(result_layout.shape()));
1290 }
1291
1292 if (computation->num_parameters() != layout.parameter_count()) {
1293 return InternalError(
1294 "Number of parameters in entry computation layout (%d) must be same "
1295 "as number of parameters of entry computation (%d)",
1296 layout.parameter_count(), computation->num_parameters());
1297 }
1298
1299 for (int i = 0; i < computation->num_parameters(); ++i) {
1300 const HloInstruction* parameter = computation->parameter_instruction(i);
1301 TF_RETURN_IF_ERROR(
1302 ShapeUtil::ValidateShapeWithOptionalLayout(layout.parameter_shape(i)));
1303 if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) {
1304 return InternalError(
1305 "Shape of the entry computation parameter %d is %s should be "
1306 "compatible to the one specified in module's entry computation "
1307 "layout %s",
1308 i, ShapeUtil::HumanString(parameter->shape()),
1309 ShapeUtil::HumanString(layout.parameter_shape(i)));
1310 }
1311 }
1312
1313 return Status::OK();
1314 }
1315
ComputationsToString(absl::Span<HloComputation * const> computations)1316 string ComputationsToString(absl::Span<HloComputation* const> computations) {
1317 return absl::StrJoin(computations, ",",
1318 [](string* s, const HloComputation* computation) {
1319 s->append(computation->name());
1320 });
1321 }
1322
1323 // Verifies various invariants about the structure of the HLO:
1324 //
1325 // (1) each instruction has a non-null parent() set to the HloComputation
1326 // which
1327 // contains it.
1328 //
1329 // (2) each computation has a non-null parent() set to the HloModule which
1330 // contains it.
1331 //
1332 // (3) the operands of each instruction are in the same computation as the
1333 // instruction.
VerifyHloStructure(HloModule * module)1334 Status VerifyHloStructure(HloModule* module) {
1335 for (const HloComputation* computation : module->computations()) {
1336 if (computation->parent() == nullptr) {
1337 return InternalError("Computation %s has a null parent pointer",
1338 computation->name());
1339 }
1340 if (computation->parent() != module) {
1341 return InternalError(
1342 "Computation %s parent() does not point to parent module",
1343 computation->name());
1344 }
1345
1346 for (const HloInstruction* instruction : computation->instructions()) {
1347 if (instruction->parent() == nullptr) {
1348 return InternalError("Instruction %s has a null parent pointer",
1349 instruction->name());
1350 }
1351 if (instruction->parent() != computation) {
1352 return InternalError(
1353 "Instruction %s parent() does not point to parent computation",
1354 instruction->name());
1355 }
1356 }
1357 }
1358
1359 // Check that operands are in the same computation separately from verifying
1360 // parent() correctness so conditions like a null HloInstruction::parent()
1361 // are identified and reported explicitly above rather than reporting a
1362 // mismatched operand.
1363 for (const HloComputation* computation : module->computations()) {
1364 for (const HloInstruction* instruction : computation->instructions()) {
1365 for (int i = 0; i < instruction->operand_count(); ++i) {
1366 const HloInstruction* operand = instruction->operand(i);
1367 if (operand->parent() != instruction->parent()) {
1368 return InternalError(
1369 "Operand %d (%s) of instruction %s is in a different "
1370 "computation: %s vs %s",
1371 i, operand->name(), instruction->name(),
1372 operand->parent() ? operand->parent()->name() : "(null)",
1373 instruction->parent()->name());
1374 }
1375 }
1376 }
1377 }
1378 return Status::OK();
1379 }
1380
1381 namespace {
1382
1383 // Returns true if the given Shape has a TOKEN shape as any subshape.
ShapeContainsToken(const Shape & shape)1384 bool ShapeContainsToken(const Shape& shape) {
1385 bool contains_token = false;
1386 ShapeUtil::ForEachSubshape(
1387 shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
1388 if (subshape.IsToken()) {
1389 contains_token = true;
1390 }
1391 });
1392 return contains_token;
1393 }
1394
1395 // Verifies that all types entering and exiting the entry computation are
1396 // legal.
VerifyEntryAndExitShapes(const HloModule & module)1397 Status VerifyEntryAndExitShapes(const HloModule& module) {
1398 // Tokens cannot be passed as entry parameters.
1399 // TODO(b/80000000): Remove this constraint.
1400 for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
1401 HloInstruction* param =
1402 module.entry_computation()->parameter_instruction(i);
1403 if (ShapeContainsToken(param->shape())) {
1404 return InternalError(
1405 "Entry parameter %d is or contains a token shape: %s", i,
1406 ShapeUtil::HumanString(param->shape()));
1407 }
1408 }
1409 return Status::OK();
1410 }
1411
1412 // Checks if the given two instructions share the same channel id.
CheckSameChannel(const HloInstruction * instr1,const HloInstruction * instr2)1413 Status CheckSameChannel(const HloInstruction* instr1,
1414 const HloInstruction* instr2) {
1415 if (instr1->channel_id() != instr2->channel_id()) {
1416 return InternalError(
1417 "Expected to have the same channel id, actual channel ids are: %s "
1418 "(%d), %s (%d)",
1419 instr1->ToString(), *instr1->channel_id(), instr2->ToString(),
1420 *instr2->channel_id());
1421 }
1422 return Status::OK();
1423 }
1424
1425 // Checks if the given two instructions have the same is_host_transfer
1426 // attribute value. Intsructions must be send/recv instructions or their
1427 // 'done' variant.
CheckSameIsHostTransfer(const HloInstruction * instr1,const HloInstruction * instr2)1428 Status CheckSameIsHostTransfer(const HloInstruction* instr1,
1429 const HloInstruction* instr2) {
1430 const HloSendRecvInstruction* send_recv1 =
1431 DynCast<const HloSendRecvInstruction>(instr1);
1432 const HloSendRecvInstruction* send_recv2 =
1433 DynCast<const HloSendRecvInstruction>(instr2);
1434 TF_RET_CHECK(send_recv1 != nullptr);
1435 TF_RET_CHECK(send_recv2 != nullptr);
1436 if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
1437 return InternalError(
1438 "Expected instructions to have the same is-host-transfer property: "
1439 "%s, "
1440 "%s ",
1441 instr1->ToString(), instr2->ToString());
1442 }
1443 return Status::OK();
1444 }
1445
VerifySingleUser(const HloInstruction * instruction,HloOpcode expected_user)1446 Status VerifySingleUser(const HloInstruction* instruction,
1447 HloOpcode expected_user) {
1448 TF_RET_CHECK(instruction->users().size() == 1)
1449 << "The " << HloOpcodeString(instruction->opcode())
1450 << " instruction requires one consumer, found "
1451 << instruction->users().size();
1452
1453 const HloInstruction* user = instruction->users().front();
1454 TF_RET_CHECK(user->opcode() == expected_user)
1455 << "The consumer of a " << HloOpcodeString(instruction->opcode())
1456 << " instruction needs to be " << HloOpcodeString(expected_user)
1457 << ", found " << HloOpcodeString(user->opcode());
1458 return Status::OK();
1459 }
1460
VerifySingleOperand(const HloInstruction * instruction,HloOpcode expected_operand)1461 Status VerifySingleOperand(const HloInstruction* instruction,
1462 HloOpcode expected_operand) {
1463 TF_RET_CHECK(instruction->operands().size() == 1)
1464 << "The " << HloOpcodeString(instruction->opcode())
1465 << " instruction requires one consumer, found "
1466 << instruction->users().size();
1467
1468 const HloInstruction* operand = instruction->operand(0);
1469 TF_RET_CHECK(operand->opcode() == expected_operand)
1470 << "The operand of a " << HloOpcodeString(instruction->opcode())
1471 << " instruction needs to be " << HloOpcodeString(expected_operand)
1472 << ", found " << HloOpcodeString(operand->opcode());
1473 return Status::OK();
1474 }
1475
1476 // Checks asynchronous instruction pairs.
VerifyAsynchronousInstructionPairs(const HloModule & module)1477 Status VerifyAsynchronousInstructionPairs(const HloModule& module) {
1478 // CopyStart must have a single CopyDone user.
1479 for (const HloComputation* computation : module.computations()) {
1480 for (const HloInstruction* instruction : computation->instructions()) {
1481 switch (instruction->opcode()) {
1482 case HloOpcode::kCopyStart: {
1483 TF_RETURN_IF_ERROR(
1484 VerifySingleUser(instruction, HloOpcode::kCopyDone));
1485 break;
1486 }
1487 case HloOpcode::kCopyDone: {
1488 TF_RETURN_IF_ERROR(
1489 VerifySingleOperand(instruction, HloOpcode::kCopyStart));
1490 break;
1491 }
1492 case HloOpcode::kCollectivePermuteStart: {
1493 TF_RETURN_IF_ERROR(
1494 VerifySingleUser(instruction, HloOpcode::kCollectivePermuteDone));
1495 break;
1496 }
1497 case HloOpcode::kCollectivePermuteDone: {
1498 TF_RETURN_IF_ERROR(VerifySingleOperand(
1499 instruction, HloOpcode::kCollectivePermuteStart));
1500 break;
1501 }
1502 default:
1503 break;
1504 }
1505 }
1506 }
1507 return Status::OK();
1508 }
1509
1510 // Checks that AllReduce instructions in the module are either all layout
1511 // constrained or all unconstrained.
VerifyLayoutConstrainedAllReduce(const HloModule & module)1512 Status VerifyLayoutConstrainedAllReduce(const HloModule& module) {
1513 const HloAllReduceInstruction* reference = nullptr;
1514 for (const HloComputation* computation : module.computations()) {
1515 for (const HloInstruction* instruction : computation->instructions()) {
1516 if (instruction->opcode() != HloOpcode::kAllReduce) {
1517 continue;
1518 }
1519 auto all_reduce = DynCast<HloAllReduceInstruction>(instruction);
1520 if (!reference) {
1521 reference = all_reduce;
1522 }
1523 if (reference->constrain_layout() != all_reduce->constrain_layout()) {
1524 return FailedPrecondition(
1525 "HloModule has a mix of layout constrained and unconstrained "
1526 "AllReduce instructions.");
1527 }
1528 }
1529 }
1530 return Status::OK();
1531 }
1532
1533 // Checks various invariants of channel instructions (send/recv and
1534 // collectives).
VerifyChannels(const HloModule & module)1535 Status VerifyChannels(const HloModule& module) {
1536 absl::flat_hash_map<int64, std::vector<const HloInstruction*>>
1537 channel_instructions;
1538
1539 // Send/Recv instruction must have a single user: the corresponding
1540 // SendDone/RecvDone. with matching channel.
1541 for (const HloComputation* computation : module.computations()) {
1542 for (const HloInstruction* instruction : computation->instructions()) {
1543 auto channel_instr = DynCast<HloChannelInstruction>(instruction);
1544 if (!channel_instr || !channel_instr->channel_id()) {
1545 continue;
1546 }
1547 channel_instructions[*channel_instr->channel_id()].push_back(instruction);
1548
1549 switch (instruction->opcode()) {
1550 case HloOpcode::kSend: {
1551 TF_RET_CHECK(instruction->users().size() == 1);
1552 const HloInstruction* send_done = instruction->users().front();
1553 TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
1554 TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done));
1555 TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done));
1556 break;
1557 }
1558 case HloOpcode::kRecv: {
1559 TF_RET_CHECK(instruction->users().size() == 1);
1560 const HloInstruction* recv_done = instruction->users().front();
1561 TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
1562 TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done));
1563 TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done));
1564 break;
1565 }
1566 case HloOpcode::kSendDone:
1567 TF_RET_CHECK(instruction->operands().size() == 1);
1568 TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend);
1569 break;
1570 case HloOpcode::kRecvDone:
1571 TF_RET_CHECK(instruction->operands().size() == 1);
1572 TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv);
1573 break;
1574 default:
1575 break;
1576 }
1577 }
1578 }
1579
1580 // Iterate over each channel to check invariants.
1581 for (auto& pair : channel_instructions) {
1582 auto& instructions = pair.second;
1583 const HloInstruction* first = instructions[0];
1584 auto sendrecv = DynCast<HloSendRecvInstruction>(first);
1585 if (sendrecv) {
1586 absl::flat_hash_set<HloOpcode> opcodes;
1587 for (const HloInstruction* instr : instructions) {
1588 opcodes.insert(instr->opcode());
1589 auto cast = DynCast<HloSendRecvInstruction>(instr);
1590 TF_RET_CHECK(cast != nullptr)
1591 << "channel " << pair.first
1592 << " is used for different types of channel instructions";
1593 }
1594 if (sendrecv->is_host_transfer()) {
1595 TF_RET_CHECK(instructions.size() == 2)
1596 << "channel " << pair.first
1597 << " is used for multiple host send/recv instructions";
1598 } else {
1599 TF_RET_CHECK(instructions.size() == opcodes.size())
1600 << "channel " << pair.first
1601 << " is used for multiple send/recv instructions";
1602 }
1603 } else {
1604 for (const HloInstruction* instr : instructions) {
1605 TF_RET_CHECK(first->opcode() == instr->opcode())
1606 << "channel " << pair.first
1607 << " is used for different types of channel instructions";
1608 }
1609 }
1610 }
1611
1612 return Status::OK();
1613 }
1614
1615 // CHECKs various invariants of a fusion instruction.
CheckFusionInstruction(HloInstruction * fusion)1616 Status CheckFusionInstruction(HloInstruction* fusion) {
1617 // The parent fusion instruction of the fusion computation must be 'fusion'.
1618 HloComputation* fused_computation = fusion->fused_instructions_computation();
1619 if (fusion != fused_computation->FusionInstruction()) {
1620 return InternalError(
1621 "Instruction of fused computation does not match expected "
1622 "instruction "
1623 "%s.",
1624 fusion->ToString());
1625 }
1626
1627 // Fused root instruction and fused parameters must all be owned by the
1628 // fusion computation.
1629 bool root_owned = false;
1630 const std::vector<HloInstruction*>& fused_parameters =
1631 fusion->fused_parameters();
1632 const HloInstruction* fused_root = fusion->fused_expression_root();
1633 std::vector<bool> parameter_owned(fused_parameters.size(), false);
1634 for (auto* instruction : fused_computation->instructions()) {
1635 if (fused_root == instruction) {
1636 if (root_owned) {
1637 return InternalError("Root appears more than once in %s.",
1638 fusion->ToString());
1639 }
1640 root_owned = true;
1641 }
1642 for (int i = 0; i < fused_parameters.size(); ++i) {
1643 if (fused_parameters[i] == instruction) {
1644 if (parameter_owned[i]) {
1645 return InternalError("Parameter appears more than once in %s.",
1646 fusion->ToString());
1647 }
1648 parameter_owned[i] = true;
1649 }
1650 }
1651 }
1652 if (!root_owned) {
1653 return InternalError("Root not found in computation of %s.",
1654 fusion->ToString());
1655 }
1656 // Make sure all the parameter_owned entries are set
1657 for (int i = 0; i < parameter_owned.size(); i++) {
1658 if (!parameter_owned[i]) {
1659 return InternalError("Parameter %d not found in computation of %s.", i,
1660 fusion->ToString());
1661 }
1662 }
1663
1664 // Fused root must have no users.
1665 if (fused_root->user_count() != 0) {
1666 return InternalError("Root of %s may not have users.", fusion->ToString());
1667 }
1668
1669 // All uses of fused instructions must be in the fusion computation, and
1670 // every non-root instruction must have at least one use.
1671 for (auto* instruction :
1672 fusion->fused_instructions_computation()->instructions()) {
1673 if (instruction != fused_root) {
1674 if (instruction->user_count() == 0) {
1675 return InternalError("Non-root instruction %s in %s must have users.",
1676 instruction->ToString(), fusion->ToString());
1677 }
1678 for (auto& user : instruction->users()) {
1679 if (fused_computation != user->parent()) {
1680 return InternalError(
1681 "Non-root instruction %s in %s may not have external users.",
1682 instruction->ToString(), fusion->ToString());
1683 }
1684 }
1685 }
1686 }
1687
1688 // Fused parameter instructions must be numbered contiguously and match up
1689 // (shapes equal) with their respective operand.
1690 CHECK_EQ(fusion->operands().size(), fused_parameters.size());
1691 std::vector<bool> parameter_numbers(fused_parameters.size(), false);
1692 for (auto fused_param : fused_parameters) {
1693 int64 param_no = fused_param->parameter_number();
1694 if (param_no < 0) {
1695 return InternalError("Unexpected negative parameter number %d in %s.",
1696 param_no, fusion->ToString());
1697 }
1698 if (param_no >= fused_parameters.size()) {
1699 return InternalError(
1700 "Unexpected parameter number %d in %s: higher then number of "
1701 "parameters %lu.",
1702 param_no, fusion->ToString(), fused_parameters.size());
1703 }
1704 if (parameter_numbers[param_no]) {
1705 return InternalError(
1706 "Did not expect parameter number %d more than once in %s.", param_no,
1707 fusion->ToString());
1708 }
1709 parameter_numbers[param_no] = true;
1710 }
1711 // Make sure all the parameter_numbers entries were seen.
1712 for (int i = 0; i < parameter_numbers.size(); i++) {
1713 if (!parameter_numbers[i]) {
1714 return InternalError("Did not see parameter number %d in %s.", i,
1715 fusion->ToString());
1716 }
1717 }
1718
1719 TF_RET_CHECK(fusion->called_computations() ==
1720 absl::Span<HloComputation* const>(
1721 {fusion->fused_instructions_computation()}))
1722 << "Fusion HLO calls computations other than the "
1723 "fused_instructions_computation: "
1724 << fusion->ToString() << " fusion->fused_instructions_computation(): "
1725 << fusion->fused_instructions_computation()->ToString()
1726 << " fusion->called_computations(): "
1727 << ComputationsToString(fusion->called_computations());
1728
1729 for (const auto& fused : fusion->fused_instructions()) {
1730 TF_RET_CHECK(fused->parent() == fusion->fused_instructions_computation())
1731 << "Fused HLO was missing a parent: " << fused->ToString()
1732 << " parent: " << fused->parent()
1733 << " computation: " << fusion->parent();
1734 }
1735
1736 // TODO(b/65423525): We'd like to check that all operands are distinct.
1737 // This is currently disabled due to the invariant being violated by
1738 // multi-output fusion.
1739 return Status::OK();
1740 }
1741
1742 // Checks that the operand shapes are compatible to the output shape, i.e.,
1743 // that there are no implicit broadcasts.
CheckElementwiseInstruction(HloInstruction * instruction)1744 Status CheckElementwiseInstruction(HloInstruction* instruction) {
1745 const Shape& out_shape = instruction->shape();
1746 for (HloInstruction* operand : instruction->operands()) {
1747 const Shape& operand_shape = operand->shape();
1748 if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
1749 return FailedPrecondition(
1750 "Implicit broadcast is not allowed in HLO."
1751 "Found different shapes for instruction %s.\n"
1752 "output: %s\noperand: %s\n",
1753 HloOpcodeString(instruction->opcode()),
1754 ShapeUtil::HumanString(out_shape),
1755 ShapeUtil::HumanString(operand_shape));
1756 }
1757 }
1758 if (auto* comparison = DynCast<HloCompareInstruction>(instruction)) {
1759 const Shape& operand_shape = comparison->operand(1)->shape();
1760 PrimitiveType operand_element_type = operand_shape.element_type();
1761 Comparison::Type default_comparison_type =
1762 Comparison::DefaultComparisonType(operand_element_type);
1763 if (primitive_util::IsFloatingPointType(operand_element_type)) {
1764 if (comparison->type() != Comparison::Type::kFloat &&
1765 comparison->type() != Comparison::Type::kFloatTotalOrder) {
1766 return FailedPrecondition(
1767 "Expected comparison type %s or %s.\n"
1768 "actual: %s\noperand: %s\n",
1769 ComparisonTypeToString(Comparison::Type::kFloat),
1770 ComparisonTypeToString(Comparison::Type::kFloatTotalOrder),
1771 ComparisonTypeToString(comparison->type()),
1772 ShapeUtil::HumanString(operand_shape));
1773 }
1774 } else if (comparison->type() != default_comparison_type) {
1775 return FailedPrecondition(
1776 "Expected comparison type %s.\n"
1777 "actual: %s\noperand: %s\n",
1778 ComparisonTypeToString(default_comparison_type),
1779 ComparisonTypeToString(comparison->type()),
1780 ShapeUtil::HumanString(operand_shape));
1781 }
1782 }
1783 return Status::OK();
1784 }
1785
1786 // Visitor which verifies various fields on the HLO instruction. This class does
1787 // not check result shape as that is checked in the ShapeVerifier.
1788 class InstructionVerifier : public DfsHloVisitorWithDefault {
1789 public:
InstructionVerifier(std::function<bool (const HloInstruction *)> instruction_can_change_layout_func)1790 explicit InstructionVerifier(std::function<bool(const HloInstruction*)>
1791 instruction_can_change_layout_func)
1792 : instruction_can_change_layout_func_(
1793 instruction_can_change_layout_func) {}
1794
DefaultAction(HloInstruction *)1795 Status DefaultAction(HloInstruction*) override { return Status::OK(); }
1796
HandleFusion(HloInstruction * fusion)1797 Status HandleFusion(HloInstruction* fusion) override {
1798 return CheckFusionInstruction(fusion);
1799 }
1800
HandleBroadcast(HloInstruction * broadcast)1801 Status HandleBroadcast(HloInstruction* broadcast) override {
1802 // If you see this failure then someone has confused the difference
1803 // between the HLO broadcast op, and the UserComputation broadcast
1804 // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
1805 // or ComputationLowerer::Visit()
1806 TF_RET_CHECK(broadcast->dimensions().size() ==
1807 broadcast->operand(0)->shape().rank())
1808 << "Broadcast HLO (" << broadcast->ToShortString()
1809 << ") has invalid number of dimensions: "
1810 << broadcast->dimensions().size()
1811 << " != " << broadcast->operand(0)->shape().rank();
1812 return Status::OK();
1813 }
1814
HandleWhile(HloInstruction * xla_while)1815 Status HandleWhile(HloInstruction* xla_while) override {
1816 auto* while_cond = xla_while->while_condition();
1817 auto* while_body = xla_while->while_body();
1818 if (while_cond->num_parameters() != 1) {
1819 return FailedPrecondition(
1820 "While condition must have exactly 1 parameter; had %d : %s",
1821 while_cond->num_parameters(), while_cond->ToString());
1822 }
1823 if (while_body->num_parameters() != 1) {
1824 return FailedPrecondition(
1825 "While body must have exactly 1 parameter; had %d : %s",
1826 while_body->num_parameters(), while_body->ToString());
1827 }
1828 if (xla_while->operand_count() != 1) {
1829 return FailedPrecondition(
1830 "While loop must have exactly one operand; had %d : %s",
1831 xla_while->operand_count(), xla_while->ToString());
1832 }
1833 return Status::OK();
1834 }
1835
HandleConditional(HloInstruction * conditional)1836 Status HandleConditional(HloInstruction* conditional) override {
1837 for (int b = 0; b < conditional->branch_count(); ++b) {
1838 if (conditional->branch_computation(b)->num_parameters() != 1) {
1839 return FailedPrecondition(
1840 "Branch computation %s of %s must have 1 parameter instead of %d",
1841 conditional->branch_computation(b)->name(), conditional->ToString(),
1842 conditional->branch_computation(b)->num_parameters());
1843 }
1844 }
1845 return Status::OK();
1846 }
1847
HandleElementwiseUnary(HloInstruction * instruction)1848 Status HandleElementwiseUnary(HloInstruction* instruction) override {
1849 return CheckElementwiseInstruction(instruction);
1850 }
1851
HandleElementwiseBinary(HloInstruction * instruction)1852 Status HandleElementwiseBinary(HloInstruction* instruction) override {
1853 return CheckElementwiseInstruction(instruction);
1854 }
1855
HandleGetTupleElement(HloInstruction * gte)1856 Status HandleGetTupleElement(HloInstruction* gte) override {
1857 TF_RET_CHECK(gte->operand(0)->shape().IsTuple());
1858 return Status::OK();
1859 }
1860
HandleTranspose(HloInstruction * transpose)1861 Status HandleTranspose(HloInstruction* transpose) override {
1862 const Shape& shape = transpose->shape();
1863 const HloInstruction* operand = transpose->operand(0);
1864 TF_RET_CHECK(shape.dimensions().size() == transpose->dimensions().size());
1865 TF_RET_CHECK(shape.dimensions().size() ==
1866 transpose->operand(0)->shape().dimensions().size());
1867 TF_RET_CHECK(std::equal(
1868 shape.dimensions().begin(), shape.dimensions().end(),
1869 Permute(operand->shape().dimensions(), transpose->dimensions())
1870 .begin()))
1871 << "shape: " << shape << ", operand->shape(): " << shape
1872 << ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ")
1873 << "}";
1874 return Status::OK();
1875 }
1876
HandleAllReduce(HloInstruction * crs)1877 Status HandleAllReduce(HloInstruction* crs) override {
1878 if (crs->channel_id().has_value()) {
1879 TF_RET_CHECK(crs->channel_id().value() > 0)
1880 << "All reduce channel id must be greater than 0 for "
1881 << crs->ToShortString();
1882 }
1883 return Status::OK();
1884 }
1885
Preprocess(HloInstruction * instruction)1886 Status Preprocess(HloInstruction* instruction) override {
1887 auto previous = instructions_by_name_.find(instruction->name());
1888 TF_RET_CHECK(previous == instructions_by_name_.end())
1889 << "HLO has name that is not unique within module:\n"
1890 << instruction->ToString()
1891 << " in computation: " << instruction->parent()->name()
1892 << "\nPrevious HLO with same name:\n"
1893 << previous->second->ToString()
1894 << " in computation: " << previous->second->parent()->name();
1895 instructions_by_name_[instruction->name()] = instruction;
1896 return Status::OK();
1897 }
1898
Postprocess(HloInstruction * instruction)1899 Status Postprocess(HloInstruction* instruction) override {
1900 if (instruction_can_change_layout_func_ &&
1901 LayoutUtil::IsDenseArray(instruction->shape()) &&
1902 !instruction_can_change_layout_func_(instruction)) {
1903 const Shape& result_shape = instruction->shape();
1904 const Layout& result_layout = result_shape.layout();
1905 for (HloInstruction* operand : instruction->operands()) {
1906 const Shape& operand_shape = operand->shape();
1907 if (LayoutUtil::IsDenseArray(operand_shape) &&
1908 operand_shape.rank() == result_shape.rank()) {
1909 const Layout& operand_layout = operand_shape.layout();
1910 TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
1911 << "Instruction shouldn't change layouts "
1912 << instruction->ToString() << " From " << result_shape << " To "
1913 << operand_shape;
1914 }
1915 }
1916 }
1917
1918 return Status::OK();
1919 }
1920
1921 private:
1922 absl::flat_hash_map<string, const HloInstruction*> instructions_by_name_;
1923 // Determines whether an instruction can change layouts.
1924 std::function<bool(const HloInstruction*)>
1925 instruction_can_change_layout_func_;
1926 };
1927
1928 } // namespace
1929
Run(HloModule * module)1930 StatusOr<bool> HloVerifier::Run(HloModule* module) {
1931 TF_RET_CHECK(!module->name().empty());
1932
1933 if (module->entry_computation()->IsFusionComputation()) {
1934 return InvalidArgument(
1935 "Module entry computation cannot be a fusion computation");
1936 }
1937
1938 TF_RETURN_IF_ERROR(VerifyHloStructure(module));
1939 TF_RETURN_IF_ERROR(VerifyAsynchronousInstructionPairs(*module));
1940 TF_RETURN_IF_ERROR(VerifyChannels(*module));
1941
1942 std::unique_ptr<ShapeVerifier> shape_verifier =
1943 target_metadata_->GetVerifier();
1944 InstructionVerifier instruction_verifier(instruction_can_change_layout_func_);
1945 for (auto* computation : module->computations()) {
1946 TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
1947 TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
1948 }
1949
1950 TF_RETURN_IF_ERROR(shape_verifier->VerifyEntryComputationLayout(*module));
1951 TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
1952
1953 // If the module has a schedule, it must be valid.
1954 if (module->has_schedule()) {
1955 TF_RETURN_IF_ERROR(module->schedule().Verify());
1956 }
1957
1958 TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(
1959 *module, [this](const Shape& shape) -> int64 {
1960 if (target_metadata_->IsLayoutSensitive()) {
1961 return target_metadata_->ShapeSize(shape);
1962 } else {
1963 return 0;
1964 }
1965 }));
1966
1967 TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module));
1968 TF_RETURN_IF_ERROR(VerifyLayoutConstrainedAllReduce(*module));
1969
1970 return false;
1971 }
1972
1973 } // namespace xla
1974