1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <set>
17
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/strings/str_join.h"
20 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
23 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
24 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
25 #include "tensorflow/compiler/xla/status_macros.h"
26 #include "tensorflow/compiler/xla/util.h"
27 #include "tensorflow/core/lib/core/errors.h"
28
29 namespace xla {
30
// Returns an InternalError if any subshape of `shape` has a sparse layout;
// sparse arrays are not yet fully supported by the HLO verifier.
Status VerifyNotSparse(const Shape& shape) {
  return ShapeUtil::ForEachSubshapeWithStatus(
      shape, [](const Shape& subshape, const ShapeIndex&) -> Status {
        if (LayoutUtil::IsSparseArray(subshape)) {
          return InternalError("Sparse arrays are not yet fully supported: %s",
                               ShapeUtil::HumanStringWithLayout(subshape));
        }
        return Status::OK();
      });
}
41
// Returns true if `hlo` is one of the opcodes that legitimately carries
// called computations (e.g. the body of a while, the reducer of a reduce).
// Used by Preprocess to reject called computations on any other opcode.
bool IsCallerInstruction(HloInstruction* hlo) {
  switch (hlo->opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kWhile:
    case HloOpcode::kAllReduce:
    case HloOpcode::kMap:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kScatter:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSort:
    case HloOpcode::kFusion:
      return true;
    default:
      return false;
  }
}
60
61 namespace {
62
CheckOperandCount(const HloInstruction * hlo,int expected)63 Status CheckOperandCount(const HloInstruction* hlo, int expected) {
64 if (hlo->operand_count() != expected) {
65 return InternalError("Expected %d operands for %s instruction: %s",
66 expected, HloOpcodeString(hlo->opcode()),
67 hlo->ToString());
68 }
69 return Status::OK();
70 }
71
// Verifies that `computation`, which is called from `calling_instruction`,
// declares exactly `expected` parameters.
Status CheckParameterCount(const HloInstruction* calling_instruction,
                           const HloComputation* computation, int expected) {
  if (computation->num_parameters() != expected) {
    return InternalError(
        "Expected computation %s called from %s to have %d parameters, has %d",
        computation->name(), calling_instruction->name(), expected,
        computation->num_parameters());
  }
  return Status::OK();
}
82
83 } // namespace
84
// Runs before each per-opcode handler: rejects called computations on
// non-caller opcodes, rejects sparse shapes, and enforces the fixed operand
// arity for opcodes that have one.
Status ShapeVerifier::Preprocess(HloInstruction* hlo) {
  if (!hlo->called_computations().empty() && !IsCallerInstruction(hlo)) {
    return InternalError(
        "Called computations specified for non-caller instruction %s",
        hlo->ToString());
  }
  TF_RETURN_IF_ERROR(VerifyNotSparse(hlo->shape()));

  // HloOpcodeArity returns nullopt for variadic opcodes; those are checked by
  // their individual handlers instead.
  absl::optional<int> arity = HloOpcodeArity(hlo->opcode());
  if (arity) {
    TF_RETURN_IF_ERROR(CheckOperandCount(hlo, *arity));
  }
  return Status::OK();
}
99
// Unary elementwise ops: result shape must match the operand's.
Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) {
  return CheckUnaryShape(hlo);
}
103
// Binary elementwise ops: result shape must match the inferred binary shape.
Status ShapeVerifier::HandleElementwiseBinary(HloInstruction* hlo) {
  return CheckBinaryShape(hlo);
}
107
// Clamp takes (min, operand, max); verified as a ternary op.
Status ShapeVerifier::HandleClamp(HloInstruction* clamp) {
  return CheckTernaryShape(clamp);
}
111
// Select takes (pred, on_true, on_false); verified as a ternary op.
Status ShapeVerifier::HandleSelect(HloInstruction* select) {
  return CheckTernaryShape(select);
}
115
// TupleSelect takes (pred, on_true, on_false) tuples; verified as ternary.
Status ShapeVerifier::HandleTupleSelect(HloInstruction* tuple_select) {
  return CheckTernaryShape(tuple_select);
}
119
HandleConcatenate(HloInstruction * concatenate)120 Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) {
121 std::vector<const Shape*> operand_shapes;
122 for (const HloInstruction* operand : concatenate->operands()) {
123 operand_shapes.push_back(&operand->shape());
124 }
125 return CheckShape(concatenate,
126 ShapeInference::InferConcatOpShape(
127 operand_shapes, concatenate->concatenate_dimension()));
128 }
129
// Convert: element-type conversion; shape inferred from the operand shape and
// the target element type of the instruction.
Status ShapeVerifier::HandleConvert(HloInstruction* convert) {
  return CheckShape(convert, ShapeInference::InferConvertShape(
                                 convert->operand(0)->shape(),
                                 convert->shape().element_type()));
}
135
// BitcastConvert: bit-level reinterpretation to the target element type.
Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) {
  return CheckShape(convert, ShapeInference::InferBitcastConvertShape(
                                 convert->operand(0)->shape(),
                                 convert->shape().element_type()));
}
141
// Copy: result shape must match the operand's.
Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
  return CheckUnaryShape(copy);
}
145
// Dot: verify against the shape inferred from both operand shapes and the
// dot dimension numbers (contracting/batch dimensions).
Status ShapeVerifier::HandleDot(HloInstruction* dot) {
  TF_ASSIGN_OR_RETURN(const Shape expected,
                      ShapeInference::InferDotOpShape(
                          dot->operand(0)->shape(), dot->operand(1)->shape(),
                          dot->dot_dimension_numbers()));
  return CheckShape(dot, expected);
}
153
// Convolution: verify against the shape inferred from lhs/rhs shapes, group
// counts, window, and convolution dimension numbers.
Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
  TF_ASSIGN_OR_RETURN(
      const Shape expected,
      ShapeInference::InferConvolveShape(
          convolution->operand(0)->shape(), convolution->operand(1)->shape(),
          convolution->feature_group_count(), convolution->batch_group_count(),
          convolution->window(), convolution->convolution_dimension_numbers()));
  return CheckShape(convolution, expected);
}
163
// Fft: verify against the shape inferred from the operand shape, FFT type,
// and FFT lengths.
Status ShapeVerifier::HandleFft(HloInstruction* fft) {
  TF_ASSIGN_OR_RETURN(
      const Shape expected,
      ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(),
                                    fft->fft_length()));
  return CheckShape(fft, expected);
}
171
// TriangularSolve: verify against the shape inferred from the matrix (a),
// right-hand side (b), and the solve options.
Status ShapeVerifier::HandleTriangularSolve(HloInstruction* hlo) {
  TF_ASSIGN_OR_RETURN(const Shape expected,
                      ShapeInference::InferTriangularSolveShape(
                          hlo->operand(0)->shape(), hlo->operand(1)->shape(),
                          hlo->triangular_solve_options()));
  return CheckShape(hlo, expected);
}
179
// Cholesky: exactly one operand; result shape inferred from it.
Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) {
  TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1));
  TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferCholeskyShape(
                                                hlo->operand(0)->shape()));
  return CheckShape(hlo, expected);
}
186
HandleAllReduce(HloInstruction * crs)187 Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) {
188 std::vector<const Shape*> operand_shapes;
189 for (const HloInstruction* operand : crs->operands()) {
190 operand_shapes.push_back(&operand->shape());
191 }
192 return CheckShape(crs, ShapeInference::InferAllReduceShape(operand_shapes));
193 }
194
// AllToAll: verify the result against the tuple shape inferred from all
// operand shapes.
Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : hlo->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  return CheckShape(hlo,
                    ShapeInference::InferAllToAllTupleShape(operand_shapes));
}
203
// ReplicaId always produces a scalar U32.
Status ShapeVerifier::HandleReplicaId(HloInstruction* hlo) {
  return CheckShape(hlo, ShapeUtil::MakeShape(U32, {}));
}
207
// CollectivePermute: result shape inferred from the single operand.
Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
  return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape(
                             hlo->operand(0)->shape()));
}
212
// ReducePrecision: result shape inferred from the operand shape and the
// requested exponent/mantissa bit widths.
Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
  return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
                                          reduce_precision->operand(0)->shape(),
                                          reduce_precision->exponent_bits(),
                                          reduce_precision->mantissa_bits()));
}
219
// Verifies that operand `operand_no` of `instruction` is token-shaped
// (used for the token operands of infeed/outfeed and similar ops).
Status ShapeVerifier::CheckIsTokenOperand(const HloInstruction* instruction,
                                          int64 operand_no) {
  const HloInstruction* token = instruction->operand(operand_no);
  if (!ShapeUtil::Equal(token->shape(), ShapeUtil::MakeTokenShape())) {
    return InternalError(
        "Expected operand %d to be token-shaped, actual shape is "
        "%s:\n%s",
        operand_no, StringifyShape(token->shape()), instruction->ToString());
  }
  return Status::OK();
}
231
// Verifies that operand `operand_number` of `instruction` has the same shape
// as parameter `parameter_number` of the called `computation`.
Status ShapeVerifier::CheckOperandAndParameter(
    const HloInstruction* instruction, int64 operand_number,
    const HloComputation* computation, int64 parameter_number) {
  const HloInstruction* operand = instruction->operand(operand_number);
  const HloInstruction* parameter =
      computation->parameter_instruction(parameter_number);
  if (!ShapesSame(operand->shape(), parameter->shape())) {
    return InternalError("Operand %s shape does not match parameter's %s in %s",
                         operand->ToString(), parameter->ToString(),
                         instruction->ToString());
  }
  return Status::OK();
}
245
// Infeed: operand 0 must be a token; the result is a tuple of the infed data
// shape and a fresh token.
Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
  HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
  TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));

  // The output of infeed is a tuple containing the data value and a token.
  return CheckShape(infeed,
                    ShapeUtil::MakeTupleShape(
                        {infeed->infeed_shape(), ShapeUtil::MakeTokenShape()}));
}
255
// Outfeed: operand 1 must be a token; the declared outfeed shape must match
// the data operand's shape; the instruction itself produces a token.
Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
  HloOutfeedInstruction* outfeed = Cast<HloOutfeedInstruction>(instruction);
  TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1));

  // Outfeed has a separate shape field for the value which is outfed to the
  // host. The shape of the instruction itself is always a token.
  if (!ShapesSame(outfeed->outfeed_shape(), outfeed->operand(0)->shape())) {
    return InternalError(
        "Expected outfeed shape to be equal to operand's shape %s, "
        "actual shape is %s:\n%s",
        StringifyShape(outfeed->operand(0)->shape()),
        StringifyShape(outfeed->outfeed_shape()), outfeed->ToString());
  }
  return CheckShape(outfeed, ShapeUtil::MakeTokenShape());
}
271
HasCompatibleElementTypes(const Shape & shape_0,const Shape & shape_1,const Shape & result_shape)272 bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0,
273 const Shape& shape_1,
274 const Shape& result_shape) {
275 return ShapeUtil::SameElementType(shape_0, shape_1) &&
276 (ShapeUtil::SameElementType(shape_0, result_shape) ||
277 (allow_mixed_precision_ &&
278 ShapeUtil::SameElementTypeIgnoringFpPrecision(shape_0,
279 result_shape)));
280 }
281
// Rng: two scalar operands (distribution parameters) whose element types must
// be compatible with the result, and whose type must be legal for the chosen
// distribution (uniform: float/int/pred; normal: float only).
Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
  TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2));

  const Shape& shape_0 = instruction->operand(0)->shape();
  const Shape& shape_1 = instruction->operand(1)->shape();
  if (!ShapeUtil::IsScalar(shape_0) || !ShapeUtil::IsScalar(shape_1)) {
    return InternalError(
        "Expected scalar types for the two operands of Rng instruction: %s",
        instruction->ToString());
  }

  if (!HasCompatibleElementTypes(shape_0, shape_1, instruction->shape())) {
    return InternalError(
        "Expected compatible element types for the result and the two operands"
        " of Rng instruction: %s",
        instruction->ToString());
  }

  PrimitiveType element_type = shape_0.element_type();
  switch (instruction->random_distribution()) {
    case RNG_UNIFORM:
      if (!primitive_util::IsFloatingPointType(element_type) &&
          !primitive_util::IsIntegralType(element_type) &&
          element_type != PRED) {
        return InternalError(
            "Element type not supported."
            " Expected element to be of floating point type, integral type or"
            " predicate type for RngUniform: %s",
            instruction->ToString());
      }
      break;

    case RNG_NORMAL:
      if (!primitive_util::IsFloatingPointType(element_type)) {
        return InternalError(
            "Element type not supported."
            " Expected element to be FloatingPointType for RngNormal: %s",
            instruction->ToString());
      }
      break;
    default:
      return InternalError(
          "Invalid Rng distribution %s",
          RandomDistribution_Name(instruction->random_distribution()));
  }

  return Status::OK();
}
330
// Reverse: result shape inferred from the operand and reversed dimensions.
Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
  return CheckShape(
      reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(),
                                                 reverse->dimensions()));
}
336
HandleSort(HloInstruction * sort)337 Status ShapeVerifier::HandleSort(HloInstruction* sort) {
338 if (sort->operand_count() < 1) {
339 return InternalError("Expected at least 1 operand for %s instruction: %s",
340 HloOpcodeString(sort->opcode()), sort->ToString());
341 }
342 HloComputation* compare = sort->to_apply();
343
344 // Check that the 'compare' computation returns a PRED.
345 Shape compare_shape = compare->root_instruction()->shape();
346 if (!ShapesSame(compare_shape, ShapeUtil::MakeShape(PRED, {}))) {
347 return InternalError(
348 "The Sort compare computation shape does not lead to a scalar "
349 "predicate shape: %s",
350 StringifyShape(compare_shape));
351 }
352
353 // Check that the number of parameters of the 'compare' computation is
354 // correct.
355 TF_RETURN_IF_ERROR(
356 CheckParameterCount(sort, compare, sort->operand_count() * 2));
357
358 // Verify that the operands of the compare computation have the correct scalar
359 // shapes.
360 for (int64 parameter_idx = 0; parameter_idx < compare->num_parameters();
361 ++parameter_idx) {
362 int64 operand_idx = parameter_idx / 2;
363 Shape expected_scalar_shape = ShapeUtil::MakeShape(
364 sort->operand(operand_idx)->shape().element_type(), {});
365 Shape actual_parameter_shape =
366 compare->parameter_instruction(parameter_idx)->shape();
367 if (!ShapeUtil::CompatibleIgnoringFpPrecision(expected_scalar_shape,
368 actual_parameter_shape)) {
369 return InternalError(
370 "Expected the %lld-th parameter of the compare computation of sort "
371 "to have shape %s, but got %s",
372 parameter_idx, StringifyShape(expected_scalar_shape),
373 StringifyShape(actual_parameter_shape));
374 }
375 }
376
377 // Verify that all operand shapes have the same dimensions.
378 for (int64 operand = 1; operand < sort->operand_count(); ++operand) {
379 if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(),
380 sort->operand(operand)->shape())) {
381 return InternalError(
382 "Expected sort to have to have the same dimensions for all operands. "
383 "First operand shape is: %s\n, shape (operand index %lld) is: %s",
384 StringifyShape(sort->operand(0)->shape()), operand,
385 StringifyShape(sort->operand(operand)->shape()));
386 }
387 }
388 return CheckVariadicShape(sort);
389 }
390
// Constant: must carry a literal, and the instruction shape must match the
// literal's shape.
Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
  if (!Cast<HloConstantInstruction>(constant)->HasLiteral()) {
    return InternalError("Constant is required to have a valid literal: %s",
                         constant->ToString());
  }
  return CheckShape(constant, constant->literal().shape());
}
398
// Iota: must produce a non-scalar array, and the iota dimension must be a
// valid dimension of the result.
Status ShapeVerifier::HandleIota(HloInstruction* instruction) {
  auto* iota = Cast<HloIotaInstruction>(instruction);
  if (!iota->shape().IsArray()) {
    return InternalError("Iota does not support non-array result.");
  }
  const int64 rank = iota->shape().rank();
  if (rank == 0) {
    return InternalError("Iota does not support scalars.");
  }
  // NOTE(review): a negative iota_dimension is not rejected here — presumably
  // guaranteed non-negative upstream; verify.
  int64 iota_dimension = iota->iota_dimension();
  if (iota_dimension >= rank) {
    return InternalError(
        "The iota dimension cannot go beyond the operation rank.");
  }
  return Status::OK();
}
415
// GetTupleElement: result shape is the tuple_index-th element of the operand
// tuple's shape.
Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  return CheckShape(get_tuple_element,
                    ShapeInference::InferGetTupleElementShape(
                        get_tuple_element->operand(0)->shape(),
                        get_tuple_element->tuple_index()));
}
422
423 namespace {
// Checks that the first `num_operands_to_check` operands of `instruction`
// have the same element types as the corresponding parameters of its to_apply
// computation. Used when mixed precision is disallowed.
Status SameElementTypesForOperandsAndToApplyParameters(
    const HloInstruction& instruction, int64 num_operands_to_check) {
  const ProgramShape& to_apply = instruction.to_apply()->ComputeProgramShape();
  for (int i = 0; i < num_operands_to_check; ++i) {
    const Shape& parameter_shape = to_apply.parameters(i);
    const Shape& operand_shape = instruction.operands()[i]->shape();
    if (!ShapeUtil::SameElementType(parameter_shape, operand_shape)) {
      return InvalidArgument(
          "Shape mismatch between to_apply computation"
          " parameter and operand %d in %s.",
          i, instruction.ToString().c_str());
    }
  }
  return Status::OK();
}
439 } // namespace
440
// Reduce: variadic with an even operand count (N inputs followed by N init
// values). Verifies the inferred reduce shape, then — unless mixed precision
// is allowed — that the input operands' element types match the reducer's
// parameters.
Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
  if (reduce->operand_count() % 2 != 0) {
    return InternalError(
        "Expected an even number of operands for %s instruction: %s",
        HloOpcodeString(reduce->opcode()), reduce->ToString());
  }

  std::vector<const Shape*> operand_shapes;
  for (const HloInstruction* operand : reduce->operands()) {
    operand_shapes.push_back(&operand->shape());
  }
  TF_RETURN_IF_ERROR(
      CheckShape(reduce, ShapeInference::InferReduceShape(
                             operand_shapes, reduce->dimensions(),
                             reduce->to_apply()->ComputeProgramShape())));

  return allow_mixed_precision_
             ? Status::OK()
             : SameElementTypesForOperandsAndToApplyParameters(
                   *reduce, reduce->operands().size() - 1);
}
462
// Bitcast: layout-changing no-op; the only shape constraint checked here is
// that the element type is preserved.
Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
  // Bitcasts are not allowed to change the element type.
  if (bitcast->operand(0)->shape().element_type() !=
      bitcast->shape().element_type()) {
    return InternalError(
        "Bitcast can not change the element type from %s to %s",
        PrimitiveType_Name(bitcast->operand(0)->shape().element_type()),
        PrimitiveType_Name(bitcast->shape().element_type()));
  }
  return Status::OK();
}
474
// Broadcast: checked explicitly (no ShapeInference analog). Element types
// must match, the dimensions() list must map every operand dimension to a
// valid output dimension, and mapped dimension sizes must agree.
Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
  // HLO broadcast has no exact analog at the proto level so there is no
  // ShapeInference method. Check the output shape explicitly.
  const Shape& operand_shape = broadcast->operand(0)->shape();
  // Check for mixed precision.
  TF_RET_CHECK(SameElementType(broadcast->shape(), operand_shape));
  TF_RET_CHECK(operand_shape.rank() == broadcast->dimensions().size());
  for (int64 operand_dimension = 0; operand_dimension < operand_shape.rank();
       ++operand_dimension) {
    // dimensions()[i] is the output dimension that operand dimension i maps
    // to; it must be in range and have the same size as the operand dim.
    int64 output_dimension = broadcast->dimensions()[operand_dimension];
    TF_RET_CHECK((output_dimension < broadcast->shape().rank()) &&
                 output_dimension >= 0 &&
                 (broadcast->shape().dimensions(output_dimension) ==
                  operand_shape.dimensions(operand_dimension)))
        << broadcast->ToString() << " operand shape " << operand_shape;
  }
  return Status::OK();
}
493
// Reshape: element type must be preserved and the total element count must be
// unchanged.
Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
  // Check for mixed precision.
  const Shape& operand_shape = reshape->operand(0)->shape();
  TF_RET_CHECK(SameElementType(reshape->shape(), operand_shape));
  TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
               ShapeUtil::ElementsIn(operand_shape));
  return Status::OK();
}
502
// Transpose: result shape inferred from the operand shape and permutation.
Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
  return CheckShape(
      transpose, ShapeInference::InferTransposeShape(
                     transpose->operand(0)->shape(), transpose->dimensions()));
}
508
// Parameter: nothing to verify here; parameter/operand agreement is checked
// at each call site (see CheckOperandAndParameter).
Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
  return Status::OK();
}
512
// Fusion: the fused computation's parameters must correspond one-to-one, by
// shape, with the fusion instruction's operands.
Status ShapeVerifier::HandleFusion(HloInstruction* fusion) {
  auto& fused_parameters = fusion->fused_parameters();
  if (fused_parameters.size() != fusion->operand_count()) {
    return InternalError(
        "Fused parameter count (%d) does not match the number of operands (%d)"
        " passed to the fusion instruction in: %s.",
        fused_parameters.size(), fusion->operand_count(),
        fusion->ToString().c_str());
  }
  for (HloInstruction* fused_param : fused_parameters) {
    int64 param_no = fused_param->parameter_number();
    if (!ShapesSame(fused_param->shape(), fusion->operand(param_no)->shape())) {
      return InternalError(
          "Shape mismatch between parameter number %d and its operand in "
          "%s.",
          param_no, fusion->ToString().c_str());
    }
  }
  return Status::OK();
}
533
// Call: operand count must match the callee's parameter count, each operand
// must match its parameter's shape, and the call's shape must match the
// callee root's shape.
Status ShapeVerifier::HandleCall(HloInstruction* call) {
  TF_RETURN_IF_ERROR(
      CheckParameterCount(call, call->to_apply(), call->operand_count()));
  for (int64 i = 0; i < call->to_apply()->num_parameters(); ++i) {
    TF_RETURN_IF_ERROR(CheckOperandAndParameter(call, i, call->to_apply(), i));
  }
  // The shape of kCall should match the shape of the computation it calls.
  return CheckShape(call, call->to_apply()->root_instruction()->shape());
}
543
// CustomCall: opaque to shape inference. When the custom call carries layout
// constraints, verify that the result and each constrained operand shape have
// layouts and that operand shapes are compatible with the constraints.
Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) {
  const HloCustomCallInstruction* custom_call =
      DynCast<const HloCustomCallInstruction>(instruction);
  TF_RET_CHECK(custom_call != nullptr);
  if (custom_call->layout_constrained()) {
    // If the layout is constrained, verify all the respective shapes have
    // layouts and that the constrained operand shapes match the shapes of the
    // operands.
    TF_RET_CHECK(LayoutUtil::HasLayout(custom_call->shape()));
    TF_RET_CHECK(custom_call->operand_count() ==
                 custom_call->operand_shapes_with_layout().size());
    for (int64 i = 0; i < custom_call->operand_count(); ++i) {
      const Shape& operand_shape_with_layout =
          custom_call->operand_shapes_with_layout()[i];
      TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(),
                                         operand_shape_with_layout))
          << custom_call->operand(i)->shape().ToString() << " operand "
          << operand_shape_with_layout.ToString();
      TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
    }
  }
  return Status::OK();
}
567
// Slice: result shape inferred from the operand shape and the
// start/limit/stride triplets.
Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
  return CheckShape(slice,
                    ShapeInference::InferSliceShape(
                        slice->operand(0)->shape(), slice->slice_starts(),
                        slice->slice_limits(), slice->slice_strides()));
}
574
// DynamicSlice: result shape inferred from the operand, the start-index
// operand shapes, and the static slice sizes.
Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) {
  return CheckShape(
      dynamic_slice,
      ShapeInference::InferDynamicSliceShape(
          dynamic_slice->operand(0)->shape(),
          Cast<HloDynamicSliceInstruction>(dynamic_slice)->index_shapes(),
          dynamic_slice->dynamic_slice_sizes()));
}
583
// DynamicUpdateSlice: result shape inferred from the operand, the update
// shape, and the start-index operand shapes.
Status ShapeVerifier::HandleDynamicUpdateSlice(
    HloInstruction* dynamic_update_slice) {
  return CheckShape(
      dynamic_update_slice,
      ShapeInference::InferDynamicUpdateSliceShape(
          dynamic_update_slice->operand(0)->shape(),
          dynamic_update_slice->operand(1)->shape(),
          Cast<HloDynamicUpdateSliceInstruction>(dynamic_update_slice)
              ->index_shapes()));
}
594
// Tuple: variadic; result must be the tuple of all operand shapes.
Status ShapeVerifier::HandleTuple(HloInstruction* tuple) {
  return CheckVariadicShape(tuple);
}
598
// Map: verifies the inferred map shape over all operands (currently all
// dimensions are map dimensions), then — unless mixed precision is allowed —
// that operand element types match the mapped computation's parameters.
Status ShapeVerifier::HandleMap(HloInstruction* map) {
  std::vector<const Shape*> operand_shapes;
  int64 max_operand_rank = 0;
  for (const HloInstruction* operand : map->operands()) {
    operand_shapes.push_back(&operand->shape());
    max_operand_rank = std::max(max_operand_rank, operand->shape().rank());
  }
  // TODO(b/65689298) Remove code below once Map is generalized to accept
  // arbitrary map dimensions.
  std::vector<int64> map_dims(max_operand_rank);
  std::iota(map_dims.begin(), map_dims.end(), 0);

  TF_RETURN_IF_ERROR(CheckShape(
      map,
      ShapeInference::InferMapShape(
          operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims)));

  return allow_mixed_precision_
             ? Status::OK()
             : SameElementTypesForOperandsAndToApplyParameters(
                   *map, map->operands().size());
}
621
// ReduceWindow: verifies the inferred shape from (input, init value, window,
// reducer), then — unless mixed precision is allowed — that the input operand
// element type matches the reducer's first parameter.
Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) {
  TF_RETURN_IF_ERROR(CheckShape(
      reduce_window,
      ShapeInference::InferReduceWindowShape(
          reduce_window->operand(0)->shape(),
          reduce_window->operand(1)->shape(), reduce_window->window(),
          reduce_window->to_apply()->ComputeProgramShape())));

  return allow_mixed_precision_
             ? Status::OK()
             : SameElementTypesForOperandsAndToApplyParameters(*reduce_window,
                                                               1);
}
635
// SelectAndScatter: result shape inferred from (operand, select computation,
// window, source, init value, scatter computation).
Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
  return CheckShape(
      instruction,
      ShapeInference::InferSelectAndScatterShape(
          instruction->operand(0)->shape(),
          instruction->select()->ComputeProgramShape(), instruction->window(),
          instruction->operand(1)->shape(), instruction->operand(2)->shape(),
          instruction->scatter()->ComputeProgramShape()));
}
645
// While: body and condition each take one parameter matching the while's
// single operand; the condition must return a scalar PRED; the while's shape
// must match the body root's shape.
Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
  TF_RETURN_IF_ERROR(
      CheckParameterCount(xla_while, xla_while->while_body(), 1));
  TF_RETURN_IF_ERROR(
      CheckParameterCount(xla_while, xla_while->while_condition(), 1));
  TF_RETURN_IF_ERROR(
      CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
  TF_RETURN_IF_ERROR(
      CheckOperandAndParameter(xla_while, 0, xla_while->while_condition(), 0));
  const Shape& conditional_shape =
      xla_while->while_condition()->root_instruction()->shape();
  if (!ShapesSame(conditional_shape, ShapeUtil::MakeShape(PRED, {}))) {
    return InternalError(
        "Conditional computation shape does not lead to a scalar predicate "
        "shape: %s",
        StringifyShape(conditional_shape));
  }
  // The shape of kWhile should match the shape of the body computation it
  // calls.
  return CheckShape(xla_while,
                    xla_while->while_body()->root_instruction()->shape());
}
668
// Conditional: a PRED selector implies exactly two branches; an index
// selector implies at least one. Operand i+1 feeds branch i's single
// parameter, and every branch root must match the conditional's shape.
Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
  const int num_branches = conditional->branch_count();
  if (conditional->operand(0)->shape().element_type() == PRED) {
    TF_RET_CHECK(num_branches == 2);
  } else {
    TF_RET_CHECK(num_branches >= 1);
  }
  TF_RETURN_IF_ERROR(CheckOperandCount(conditional, num_branches + 1));
  for (int j = 0; j < num_branches; ++j) {
    TF_RETURN_IF_ERROR(CheckParameterCount(
        conditional, conditional->branch_computation(j), 1));
    TF_RETURN_IF_ERROR(CheckOperandAndParameter(
        conditional, j + 1, conditional->branch_computation(j), 0));
    TF_RETURN_IF_ERROR(CheckShape(
        conditional,
        conditional->branch_computation(j)->root_instruction()->shape()));
  }
  return Status::OK();
}
688
// Pad: result shape inferred from (operand, padding value, padding config).
Status ShapeVerifier::HandlePad(HloInstruction* pad) {
  return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(),
                                                       pad->operand(1)->shape(),
                                                       pad->padding_config()));
}
694
// Send: produces (sent data, U32 request identifier, token).
Status ShapeVerifier::HandleSend(HloInstruction* send) {
  return CheckShape(send,
                    ShapeUtil::MakeTupleShape({send->operand(0)->shape(),
                                               ShapeUtil::MakeShape(U32, {}),
                                               ShapeUtil::MakeTokenShape()}));
}
701
// SendDone: completes a send; produces only a token.
Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
  return CheckShape(send_done, ShapeUtil::MakeTokenShape());
}
705
// Recv: produces (received data, U32 request identifier, token); the data
// shape is taken from element 0 of the instruction's own declared shape.
Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
  return CheckShape(
      recv, ShapeUtil::MakeTupleShape(
                {ShapeUtil::GetTupleElementShape(recv->shape(), 0),
                 ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()}));
}
712
// RecvDone: completes a recv; produces (received data, token), where the data
// shape comes from element 0 of the recv operand's shape.
Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
  return CheckShape(
      recv_done,
      ShapeUtil::MakeTupleShape(
          {ShapeUtil::GetTupleElementShape(recv_done->operand(0)->shape(), 0),
           ShapeUtil::MakeTokenShape()}));
}
720
// BatchNormTraining: result shape inferred from (operand, scale, offset) and
// the feature index.
Status ShapeVerifier::HandleBatchNormTraining(
    HloInstruction* batch_norm_training) {
  return CheckShape(batch_norm_training,
                    ShapeInference::InferBatchNormTrainingShape(
                        batch_norm_training->operand(0)->shape(),
                        batch_norm_training->operand(1)->shape(),
                        batch_norm_training->operand(2)->shape(),
                        batch_norm_training->feature_index()));
}
730
// BatchNormInference: result shape inferred from (operand, scale, offset,
// mean, variance) and the feature index.
Status ShapeVerifier::HandleBatchNormInference(
    HloInstruction* batch_norm_inference) {
  return CheckShape(batch_norm_inference,
                    ShapeInference::InferBatchNormInferenceShape(
                        batch_norm_inference->operand(0)->shape(),
                        batch_norm_inference->operand(1)->shape(),
                        batch_norm_inference->operand(2)->shape(),
                        batch_norm_inference->operand(3)->shape(),
                        batch_norm_inference->operand(4)->shape(),
                        batch_norm_inference->feature_index()));
}
742
// BatchNormGrad: result shape inferred from (operand, scale, mean, variance,
// grad_output) and the feature index.
Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
  return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape(
                                         batch_norm_grad->operand(0)->shape(),
                                         batch_norm_grad->operand(1)->shape(),
                                         batch_norm_grad->operand(2)->shape(),
                                         batch_norm_grad->operand(3)->shape(),
                                         batch_norm_grad->operand(4)->shape(),
                                         batch_norm_grad->feature_index()));
}
752
753 namespace {
754
755 // Checks that the instruction does not have mixed precision floating point
756 // inputs.
// Checks that the instruction does not have mixed precision floating point
// inputs: every floating-point leaf subshape across all operands must share
// one element type. Opcodes that merely pass data through or regroup it via
// tuples are exempt.
Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
  switch (instruction->opcode()) {
    // White list the following opcodes for mixed-precision check, because
    // they involve data pass through or grouping via tuples, where the
    // precisions of buffers can be different.
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kConstant:
    case HloOpcode::kAllReduce:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kFusion:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kParameter:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kTupleSelect:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kSort:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      break;
    default: {
      // Record the first floating-point type seen; any later, different
      // floating-point type is a mixed-precision violation.
      PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID;
      for (auto operand : instruction->operands()) {
        TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
            operand->shape(),
            [&](const Shape& subshape, const ShapeIndex& index) {
              if (!ShapeUtil::ElementIsFloating(subshape)) {
                return Status::OK();
              }
              if (fp_type == PRIMITIVE_TYPE_INVALID) {
                fp_type = subshape.element_type();
              } else if (fp_type != subshape.element_type()) {
                return InternalError(
                    "Seen floating point types of different precisions in "
                    "%s, but mixed precision is disallowed.",
                    instruction->ToString());
              }
              return Status::OK();
            }));
      }
    }
  }
  return Status::OK();
}
807
808 } // namespace
809
HandleGather(HloInstruction * gather)810 Status ShapeVerifier::HandleGather(HloInstruction* gather) {
811 return CheckShape(
812 gather,
813 ShapeInference::InferGatherShape(
814 gather->operand(0)->shape(), gather->operand(1)->shape(),
815 gather->gather_dimension_numbers(), gather->gather_slice_sizes()));
816 }
817
HandleScatter(HloInstruction * scatter)818 Status ShapeVerifier::HandleScatter(HloInstruction* scatter) {
819 return CheckShape(
820 scatter, ShapeInference::InferScatterShape(
821 scatter->operand(0)->shape(), scatter->operand(1)->shape(),
822 scatter->operand(2)->shape(),
823 scatter->to_apply()->ComputeProgramShape(),
824 scatter->scatter_dimension_numbers()));
825 }
826
HandleAfterAll(HloInstruction * token)827 Status ShapeVerifier::HandleAfterAll(HloInstruction* token) {
828 std::vector<const Shape*> operand_shapes;
829 for (const HloInstruction* operand : token->operands()) {
830 operand_shapes.push_back(&operand->shape());
831 }
832 return CheckShape(token, ShapeUtil::MakeTokenShape());
833 }
834
HandleAddDependency(HloInstruction * add_dependency)835 Status ShapeVerifier::HandleAddDependency(HloInstruction* add_dependency) {
836 TF_RETURN_IF_ERROR(CheckIsTokenOperand(add_dependency, 1));
837 return CheckShape(add_dependency, add_dependency->operand(0)->shape());
838 }
839
HandleGetDimensionSize(HloInstruction * get_size)840 Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) {
841 return CheckShape(get_size,
842 ShapeInference::InferGetDimensionSizeShape(
843 get_size->operand(0)->shape(), get_size->dimension()));
844 }
845
CheckShape(const HloInstruction * instruction,const Shape & inferred_shape)846 Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
847 const Shape& inferred_shape) {
848 // If allow_mixed_precision_ is false, check if there are operands with
849 // different precisions. We need this check because ShapeInference allows
850 // mixed precision inputs.
851 if (!allow_mixed_precision_) {
852 TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction));
853 }
854
855 // Check if the output shape matches the expected shape.
856 //
857 // We treat BF16 and F32 as compatible types if mixed precision is allowed,
858 // but only when the instruction defines the BF16/F32 buffer.
859 bool equal = [&] {
860 switch (instruction->opcode()) {
861 // The opcodes below can't have implicit layout conversions, nor can they
862 // implicitly transform f32 -> bf16. Fundamentally these are either
863 // reinterpreting existing data (e.g. kBitcast) or shuffling data around
864 // without modifying it (e.g. kGetTupleElement, kTupleSelect).
865 case HloOpcode::kBitcast:
866 case HloOpcode::kCall:
867 case HloOpcode::kConditional:
868 case HloOpcode::kConstant:
869 case HloOpcode::kCustomCall:
870 case HloOpcode::kGetTupleElement:
871 case HloOpcode::kInfeed:
872 case HloOpcode::kOutfeed:
873 case HloOpcode::kParameter:
874 case HloOpcode::kRecv:
875 case HloOpcode::kRecvDone:
876 case HloOpcode::kSend:
877 case HloOpcode::kSendDone:
878 case HloOpcode::kTuple:
879 case HloOpcode::kTupleSelect:
880 case HloOpcode::kWhile:
881 return ShapesSame(instruction->shape(), inferred_shape);
882
883 // We allow arbitrary layout and f32->bf16 transformations on all other
884 // instructions, although this may be made more strict pending discussion
885 // in b/112709536.
886 default:
887 if (allow_mixed_precision_) {
888 return ShapeUtil::CompatibleIgnoringFpPrecision(instruction->shape(),
889 inferred_shape);
890 } else {
891 return ShapeUtil::Compatible(instruction->shape(), inferred_shape);
892 }
893 }
894 }();
895 if (!equal) {
896 return InternalError(
897 "Expected instruction to have shape equal to %s, actual "
898 "shape is %s:\n%s",
899 StringifyShape(inferred_shape), StringifyShape(instruction->shape()),
900 instruction->ToString());
901 }
902 return Status::OK();
903 }
904
CheckShape(const HloInstruction * instruction,const StatusOr<Shape> & inferred_shape_status)905 Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
906 const StatusOr<Shape>& inferred_shape_status) {
907 if (!inferred_shape_status.ok()) {
908 Status s = inferred_shape_status.status();
909 tensorflow::errors::AppendToMessage(&s, ", for instruction ",
910 instruction->ToString());
911 return s;
912 }
913 return CheckShape(instruction, inferred_shape_status.ValueOrDie());
914 }
915
CheckUnaryShape(const HloInstruction * instruction)916 Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) {
917 return CheckShape(instruction,
918 ShapeInference::InferUnaryOpShape(instruction->opcode(),
919 instruction->operand(0)));
920 }
921
CheckBinaryShape(const HloInstruction * instruction)922 Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) {
923 return CheckShape(
924 instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(),
925 instruction->operand(0),
926 instruction->operand(1)));
927 }
928
CheckTernaryShape(const HloInstruction * instruction)929 Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) {
930 return CheckShape(instruction,
931 ShapeInference::InferTernaryOpShape(
932 instruction->opcode(), instruction->operand(0),
933 instruction->operand(1), instruction->operand(2)));
934 }
935
CheckVariadicShape(const HloInstruction * instruction)936 Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
937 return CheckShape(instruction,
938 ShapeInference::InferVariadicOpShape(
939 instruction->opcode(), instruction->operands()));
940 }
941
VerifyEntryComputationLayout(const HloModule & module)942 Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) {
943 const HloComputation* computation = module.entry_computation();
944 const auto& layout = module.entry_computation_layout();
945 const ShapeLayout& result_layout = layout.result_layout();
946
947 TF_RETURN_IF_ERROR(
948 ShapeUtil::ValidateShapeWithOptionalLayout(result_layout.shape()));
949
950 TF_RETURN_IF_ERROR(VerifyNotSparse(result_layout.shape()));
951
952 if (!ShapeUtil::Compatible(computation->root_instruction()->shape(),
953 result_layout.shape())) {
954 return InternalError(
955 "Shape of the root instruction of entry computation (%s) should be "
956 "compatible to one specified in module's entry computation layout (%s)",
957 ShapeUtil::HumanString(computation->root_instruction()->shape()),
958 ShapeUtil::HumanString(result_layout.shape()));
959 }
960
961 if (computation->num_parameters() != layout.parameter_count()) {
962 return InternalError(
963 "Number of parameters in entry computation layout (%d) must be same "
964 "as number of parameters of entry computation computation (%d)",
965 layout.parameter_count(), computation->num_parameters());
966 }
967
968 for (int i = 0; i < computation->num_parameters(); ++i) {
969 const HloInstruction* parameter = computation->parameter_instruction(i);
970 TF_RETURN_IF_ERROR(
971 ShapeUtil::ValidateShapeWithOptionalLayout(layout.parameter_shape(i)));
972 TF_RETURN_IF_ERROR(VerifyNotSparse(layout.parameter_shape(i)));
973 if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) {
974 return InternalError(
975 "Shape of the entry computation parameter %d is %s should be "
976 "compatible to the one specified in module's entry computation "
977 "layout %s",
978 i, ShapeUtil::HumanString(parameter->shape()),
979 ShapeUtil::HumanString(layout.parameter_shape(i)));
980 }
981 }
982
983 return Status::OK();
984 }
985
ComputationsToString(absl::Span<HloComputation * const> computations)986 string ComputationsToString(absl::Span<HloComputation* const> computations) {
987 return absl::StrJoin(computations, ",",
988 [](string* s, const HloComputation* computation) {
989 s->append(computation->name());
990 });
991 }
992
993 // Verifies various invariants about the structure of the HLO:
994 //
995 // (1) each instruction has a non-null parent() set to the HloComputation
996 // which
997 // contains it.
998 //
999 // (2) each computation has a non-null parent() set to the HloModule which
1000 // contains it.
1001 //
1002 // (3) the operands of each instruction are in the same computation as the
1003 // instruction.
VerifyHloStructure(HloModule * module)1004 Status VerifyHloStructure(HloModule* module) {
1005 for (const HloComputation* computation : module->computations()) {
1006 if (computation->parent() == nullptr) {
1007 return InternalError("Computation %s has a null parent pointer",
1008 computation->name());
1009 }
1010 if (computation->parent() != module) {
1011 return InternalError(
1012 "Computation %s parent() does not point to parent module",
1013 computation->name());
1014 }
1015
1016 for (const HloInstruction* instruction : computation->instructions()) {
1017 if (instruction->parent() == nullptr) {
1018 return InternalError("Instruction %s has a null parent pointer",
1019 instruction->name());
1020 }
1021 if (instruction->parent() != computation) {
1022 return InternalError(
1023 "Instruction %s parent() does not point to parent computation",
1024 instruction->name());
1025 }
1026 }
1027 }
1028
1029 // Check that operands are in the same computation separately from verifying
1030 // parent() correctness so conditions like a null HloInstruction::parent()
1031 // are identified and reported explicitly above rather than reporting a
1032 // mismatched operand.
1033 for (const HloComputation* computation : module->computations()) {
1034 for (const HloInstruction* instruction : computation->instructions()) {
1035 for (int i = 0; i < instruction->operand_count(); ++i) {
1036 const HloInstruction* operand = instruction->operand(i);
1037 if (operand->parent() != instruction->parent()) {
1038 return InternalError(
1039 "Operand %d (%s) of instruction %s is in a different "
1040 "computation: %s vs %s",
1041 i, operand->name(), instruction->name(),
1042 operand->parent()->name(), instruction->parent()->name());
1043 }
1044 }
1045 }
1046 }
1047 return Status::OK();
1048 }
1049
1050 namespace {
1051
1052 // Returns true if the given Shape has a TOKEN shape as any subshape.
ShapeContainsToken(const Shape & shape)1053 bool ShapeContainsToken(const Shape& shape) {
1054 bool contains_token = false;
1055 ShapeUtil::ForEachSubshape(
1056 shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
1057 if (subshape.IsToken()) {
1058 contains_token = true;
1059 }
1060 });
1061 return contains_token;
1062 }
1063
1064 // Verifies that all types entering and exiting the entry computation are
1065 // legal.
VerifyEntryAndExitShapes(const HloModule & module)1066 Status VerifyEntryAndExitShapes(const HloModule& module) {
1067 // Tokens cannot be passed as entry parameters.
1068 // TODO(b/80000000): Remove this constraint.
1069 for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
1070 HloInstruction* param =
1071 module.entry_computation()->parameter_instruction(i);
1072 if (ShapeContainsToken(param->shape())) {
1073 return InternalError(
1074 "Entry parameter %d is or contains a token shape: %s", i,
1075 ShapeUtil::HumanString(param->shape()));
1076 }
1077 }
1078 return Status::OK();
1079 }
1080
1081 // Checks if the given two instructions share the same channel id.
CheckSameChannel(const HloInstruction * instr1,const HloInstruction * instr2)1082 Status CheckSameChannel(const HloInstruction* instr1,
1083 const HloInstruction* instr2) {
1084 if (instr1->channel_id() != instr2->channel_id()) {
1085 return InternalError(
1086 "Expected to have the same channel id, actual channel ids are: %s "
1087 "(%d), %s (%d)",
1088 instr1->ToString(), instr1->channel_id(), instr2->ToString(),
1089 instr2->channel_id());
1090 }
1091 return Status::OK();
1092 }
1093
1094 // Checks if the given two instructions have the same is_host_transfer
1095 // attribute value. Intsructions must be send/recv instructions or their
1096 // 'done' variant.
CheckSameIsHostTransfer(const HloInstruction * instr1,const HloInstruction * instr2)1097 Status CheckSameIsHostTransfer(const HloInstruction* instr1,
1098 const HloInstruction* instr2) {
1099 const HloSendRecvInstruction* send_recv1 =
1100 DynCast<const HloSendRecvInstruction>(instr1);
1101 const HloSendRecvInstruction* send_recv2 =
1102 DynCast<const HloSendRecvInstruction>(instr2);
1103 TF_RET_CHECK(send_recv1 != nullptr);
1104 TF_RET_CHECK(send_recv2 != nullptr);
1105 if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
1106 return InternalError(
1107 "Expected instructions to have the same is-host-transfer property: "
1108 "%s, "
1109 "%s ",
1110 instr1->ToString(), instr2->ToString());
1111 }
1112 return Status::OK();
1113 }
1114
// Checks various invariants of send and recv instructions:
//  - a host-transfer send/recv may not share its channel id with any other
//    host send/recv in the module;
//  - every Send/Recv has exactly one user, the matching SendDone/RecvDone,
//    with the same channel id and is_host_transfer attribute;
//  - every SendDone/RecvDone has exactly one operand, a Send/Recv.
Status VerifySendsAndRecvs(const HloModule& module) {
  // Maps channel id -> first host send/recv instruction observed on it, used
  // to detect channel reuse across host transfers.
  absl::flat_hash_map<int64, const HloInstruction*> host_channels;
  // Host send/recv instructions must have their own unique channel.
  auto check_unique_host_channel = [&](const HloInstruction* instruction) {
    const HloSendRecvInstruction* sendrecv =
        DynCast<const HloSendRecvInstruction>(instruction);
    if (sendrecv->is_host_transfer()) {
      auto it_inserted =
          host_channels.insert({sendrecv->channel_id(), sendrecv});
      if (!it_inserted.second) {
        // Insert failed: some other host send/recv already claimed this
        // channel; report both instructions.
        return FailedPrecondition(
            "Channel %d is used for multiple host send/recv instructions: "
            "%s "
            "and "
            "%s",
            sendrecv->channel_id(), sendrecv->ToString(),
            it_inserted.first->second->ToString());
      }
    }

    return Status::OK();
  };

  // Send/Recv instruction must have a single user: the corresponding
  // SendDone/RecvDone with a matching channel.
  for (const HloComputation* computation : module.computations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      switch (instruction->opcode()) {
        case HloOpcode::kSend: {
          TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
          TF_RET_CHECK(instruction->users().size() == 1);
          const HloInstruction* send_done = instruction->users().front();
          TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
          TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done));
          TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done));
          break;
        }
        case HloOpcode::kRecv: {
          TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
          TF_RET_CHECK(instruction->users().size() == 1);
          const HloInstruction* recv_done = instruction->users().front();
          TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
          TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done));
          TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done));
          break;
        }
        // The Done sides only need the operand-side check; the channel and
        // host-transfer checks were already performed from the Send/Recv.
        case HloOpcode::kSendDone:
          TF_RET_CHECK(instruction->operands().size() == 1);
          TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend);
          break;
        case HloOpcode::kRecvDone:
          TF_RET_CHECK(instruction->operands().size() == 1);
          TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv);
          break;
        default:
          break;
      }
    }
  }
  return Status::OK();
}
1177
1178 // CHECKs various invariants of a fusion instruction.
CheckFusionInstruction(HloInstruction * fusion)1179 Status CheckFusionInstruction(HloInstruction* fusion) {
1180 // The parent fusion instruction of the fusion computation must be 'fusion'.
1181 HloComputation* fused_computation = fusion->fused_instructions_computation();
1182 if (fusion != fused_computation->FusionInstruction()) {
1183 return InternalError(
1184 "Instruction of fused computation does not match expected "
1185 "instruction "
1186 "%s.",
1187 fusion->ToString());
1188 }
1189
1190 // Fused root instruction and fused parameters must all be owned by the
1191 // fusion computation.
1192 bool root_owned = false;
1193 const std::vector<HloInstruction*>& fused_parameters =
1194 fusion->fused_parameters();
1195 const HloInstruction* fused_root = fusion->fused_expression_root();
1196 std::vector<bool> parameter_owned(fused_parameters.size(), false);
1197 for (auto* instruction : fused_computation->instructions()) {
1198 if (fused_root == instruction) {
1199 if (root_owned) {
1200 return InternalError("Root appears more than once in %s.",
1201 fusion->ToString());
1202 }
1203 root_owned = true;
1204 }
1205 for (int i = 0; i < fused_parameters.size(); ++i) {
1206 if (fused_parameters[i] == instruction) {
1207 if (parameter_owned[i]) {
1208 return InternalError("Parameter appears more than once in %s.",
1209 fusion->ToString());
1210 }
1211 parameter_owned[i] = true;
1212 }
1213 }
1214 }
1215 if (!root_owned) {
1216 return InternalError("Root not found in computation of %s.",
1217 fusion->ToString());
1218 }
1219 // Make sure all the parameter_owned entries are set
1220 for (int i = 0; i < parameter_owned.size(); i++) {
1221 if (!parameter_owned[i]) {
1222 return InternalError("Parameter %d not found in computation of %s.", i,
1223 fusion->ToString());
1224 }
1225 }
1226
1227 // Fused root must have no users.
1228 if (fused_root->user_count() != 0) {
1229 return InternalError("Root of %s may not have users.", fusion->ToString());
1230 }
1231
1232 // All uses of fused instructions must be in the fusion computation, and
1233 // every non-root instruction must have at least one use.
1234 for (auto* instruction :
1235 fusion->fused_instructions_computation()->instructions()) {
1236 if (instruction != fused_root) {
1237 if (instruction->user_count() == 0) {
1238 return InternalError("Non-root instruction %s in %s must have users.",
1239 instruction->ToString(), fusion->ToString());
1240 }
1241 for (auto& user : instruction->users()) {
1242 if (fused_computation != user->parent()) {
1243 return InternalError(
1244 "Non-root instruction %s in %s may not have external users.",
1245 instruction->ToString(), fusion->ToString());
1246 }
1247 }
1248 }
1249 }
1250
1251 // Fused parameter instructions must be numbered contiguously and match up
1252 // (shapes equal) with their respective operand.
1253 CHECK_EQ(fusion->operands().size(), fused_parameters.size());
1254 std::vector<bool> parameter_numbers(fused_parameters.size(), false);
1255 for (auto fused_param : fused_parameters) {
1256 int64 param_no = fused_param->parameter_number();
1257 if (param_no < 0) {
1258 return InternalError("Unexpected negative parameter number %d in %s.",
1259 param_no, fusion->ToString());
1260 }
1261 if (param_no >= fused_parameters.size()) {
1262 return InternalError(
1263 "Unexpected parameter number %d in %s: higher then number of "
1264 "parameters %lu.",
1265 param_no, fusion->ToString(), fused_parameters.size());
1266 }
1267 if (parameter_numbers[param_no]) {
1268 return InternalError(
1269 "Did not expect parameter number %d more than once in %s.", param_no,
1270 fusion->ToString());
1271 }
1272 parameter_numbers[param_no] = true;
1273 }
1274 // Make sure all the parameter_numbers entries were seen.
1275 for (int i = 0; i < parameter_numbers.size(); i++) {
1276 if (!parameter_numbers[i]) {
1277 return InternalError("Did not see parameter number %d in %s.", i,
1278 fusion->ToString());
1279 }
1280 }
1281
1282 TF_RET_CHECK(fusion->called_computations() ==
1283 absl::Span<HloComputation* const>(
1284 {fusion->fused_instructions_computation()}))
1285 << "Fusion HLO calls computations other than the "
1286 "fused_instructions_computation: "
1287 << fusion->ToString() << " fusion->fused_instructions_computation(): "
1288 << fusion->fused_instructions_computation()->ToString()
1289 << " fusion->called_computations(): "
1290 << ComputationsToString(fusion->called_computations());
1291
1292 for (const auto& fused : fusion->fused_instructions()) {
1293 TF_RET_CHECK(fused->parent() == fusion->fused_instructions_computation())
1294 << "Fused HLO was missing a parent: " << fused->ToString()
1295 << " parent: " << fused->parent()
1296 << " computation: " << fusion->parent();
1297 }
1298
1299 // TODO(b/65423525): We'd like to check that all operands are distinct.
1300 // This is currently disabled due to the invariant being violated by
1301 // multi-output fusion.
1302 return Status::OK();
1303 }
1304
1305 // Checks that the operand shapes are compatible to the output shape, i.e.,
1306 // that there are no implicit broadcasts.
CheckElementwiseInstruction(HloInstruction * instruction)1307 Status CheckElementwiseInstruction(HloInstruction* instruction) {
1308 const Shape& out_shape = instruction->shape();
1309 for (HloInstruction* operand : instruction->operands()) {
1310 const Shape& operand_shape = operand->shape();
1311 if (!ShapeUtil::CompatibleIgnoringElementType(operand_shape, out_shape)) {
1312 return FailedPrecondition(
1313 "Implicit broadcast is not allowed in HLO."
1314 "Found different shapes for instruction %s.\n"
1315 "output: %s\noperand: %s\n",
1316 HloOpcodeString(instruction->opcode()),
1317 ShapeUtil::HumanString(out_shape),
1318 ShapeUtil::HumanString(operand_shape));
1319 }
1320 }
1321 return Status::OK();
1322 }
1323
1324 // Visitor which verifies various fields on the HLO instruction. This class does
1325 // not check result shape as that is checked in the ShapeVerifier.
1326 class InstructionVerifier : public DfsHloVisitorWithDefault {
1327 public:
InstructionVerifier(std::function<bool (const HloInstruction *)> instruction_can_change_layout_func)1328 explicit InstructionVerifier(std::function<bool(const HloInstruction*)>
1329 instruction_can_change_layout_func)
1330 : instruction_can_change_layout_func_(
1331 instruction_can_change_layout_func) {}
1332
DefaultAction(HloInstruction *)1333 Status DefaultAction(HloInstruction*) override { return Status::OK(); }
1334
HandleFusion(HloInstruction * fusion)1335 Status HandleFusion(HloInstruction* fusion) override {
1336 return CheckFusionInstruction(fusion);
1337 }
1338
HandleBroadcast(HloInstruction * broadcast)1339 Status HandleBroadcast(HloInstruction* broadcast) override {
1340 // If you see this failure then someone has confused the difference
1341 // between the HLO broadcast op, and the UserComputation broadcast
1342 // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
1343 // or ComputationLowerer::Visit()
1344 TF_RET_CHECK(broadcast->dimensions().size() ==
1345 broadcast->operand(0)->shape().rank())
1346 << "Broadcast HLO (" << broadcast->ToShortString()
1347 << ") has invalid number of dimensions: "
1348 << broadcast->dimensions().size()
1349 << " != " << broadcast->operand(0)->shape().rank();
1350 return Status::OK();
1351 }
1352
HandleWhile(HloInstruction * xla_while)1353 Status HandleWhile(HloInstruction* xla_while) override {
1354 auto* while_cond = xla_while->while_condition();
1355 auto* while_body = xla_while->while_body();
1356 if (while_cond->num_parameters() != 1) {
1357 return FailedPrecondition(
1358 "While condition must have exactly 1 parameter; had %d : %s",
1359 while_cond->num_parameters(), while_cond->ToString());
1360 }
1361 if (while_body->num_parameters() != 1) {
1362 return FailedPrecondition(
1363 "While body must have exactly 1 parameter; had %d : %s",
1364 while_body->num_parameters(), while_body->ToString());
1365 }
1366 if (xla_while->operand_count() != 1) {
1367 return FailedPrecondition(
1368 "While loop must have exactly one operand; had %d : %s",
1369 xla_while->operand_count(), xla_while->ToString());
1370 }
1371 return Status::OK();
1372 }
1373
HandleConditional(HloInstruction * conditional)1374 Status HandleConditional(HloInstruction* conditional) override {
1375 for (int b = 0; b < conditional->branch_count(); ++b) {
1376 if (conditional->branch_computation(b)->num_parameters() != 1) {
1377 return FailedPrecondition(
1378 "Branch computation %s of %s must have 1 parameter insted of %d",
1379 conditional->branch_computation(b)->name(), conditional->ToString(),
1380 conditional->branch_computation(b)->num_parameters());
1381 }
1382 }
1383 return Status::OK();
1384 }
1385
HandleElementwiseUnary(HloInstruction * instruction)1386 Status HandleElementwiseUnary(HloInstruction* instruction) override {
1387 return CheckElementwiseInstruction(instruction);
1388 }
1389
HandleElementwiseBinary(HloInstruction * instruction)1390 Status HandleElementwiseBinary(HloInstruction* instruction) override {
1391 return CheckElementwiseInstruction(instruction);
1392 }
1393
HandleGetTupleElement(HloInstruction * gte)1394 Status HandleGetTupleElement(HloInstruction* gte) override {
1395 TF_RET_CHECK(gte->operand(0)->shape().IsTuple());
1396 return Status::OK();
1397 }
1398
HandleTranspose(HloInstruction * transpose)1399 Status HandleTranspose(HloInstruction* transpose) override {
1400 const Shape& shape = transpose->shape();
1401 const HloInstruction* operand = transpose->operand(0);
1402 TF_RET_CHECK(shape.dimensions().size() == transpose->dimensions().size());
1403 TF_RET_CHECK(shape.dimensions().size() ==
1404 transpose->operand(0)->shape().dimensions().size());
1405 TF_RET_CHECK(std::equal(
1406 operand->shape().dimensions().begin(),
1407 operand->shape().dimensions().end(),
1408 Permute(transpose->dimensions(), shape.dimensions()).begin()))
1409 << "shape: " << shape << ", operand->shape(): " << shape
1410 << ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ")
1411 << "}";
1412 return Status::OK();
1413 }
1414
HandleAllReduce(HloInstruction * crs)1415 Status HandleAllReduce(HloInstruction* crs) override {
1416 if (crs->all_reduce_id().has_value()) {
1417 TF_RET_CHECK(crs->all_reduce_id().value() > 0)
1418 << "All reduce id must be greater than 0 for "
1419 << crs->ToShortString();
1420 }
1421 return Status::OK();
1422 }
1423
Preprocess(HloInstruction * instruction)1424 Status Preprocess(HloInstruction* instruction) override {
1425 auto previous = instructions_by_name_.find(instruction->name());
1426 TF_RET_CHECK(previous == instructions_by_name_.end())
1427 << "HLO has name that is not unique within module:\n"
1428 << instruction->ToString()
1429 << " in computation: " << instruction->parent()->name()
1430 << "\nPrevious HLO with same name:\n"
1431 << previous->second->ToString()
1432 << " in computation: " << previous->second->parent()->name();
1433 instructions_by_name_[instruction->name()] = instruction;
1434 return Status::OK();
1435 }
1436
Postprocess(HloInstruction * instruction)1437 Status Postprocess(HloInstruction* instruction) override {
1438 if (instruction_can_change_layout_func_ &&
1439 LayoutUtil::IsDenseArray(instruction->shape()) &&
1440 !instruction_can_change_layout_func_(instruction)) {
1441 const Shape& result_shape = instruction->shape();
1442 const Layout& result_layout = result_shape.layout();
1443 for (HloInstruction* operand : instruction->operands()) {
1444 const Shape& operand_shape = operand->shape();
1445 if (LayoutUtil::IsDenseArray(operand_shape) &&
1446 operand_shape.rank() == result_shape.rank()) {
1447 const Layout& operand_layout = operand_shape.layout();
1448 TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
1449 << "Instruction shouldn't change layouts "
1450 << instruction->ToString() << " From " << result_shape << " To "
1451 << operand_shape;
1452 }
1453 }
1454 }
1455
1456 return Status::OK();
1457 }
1458
1459 private:
1460 absl::flat_hash_map<string, const HloInstruction*> instructions_by_name_;
1461 // Determines whether an instruction can change layouts.
1462 std::function<bool(const HloInstruction*)>
1463 instruction_can_change_layout_func_;
1464 };
1465
1466 } // namespace
1467
// Runs all verification passes over 'module'. Returns an error status on the
// first violated invariant, or false (module unchanged) on success. The pass
// order matters: structural invariants are established before per-instruction
// shape/field checks, which in turn precede the module-level layout, schedule,
// aliasing and dynamic-binding checks.
StatusOr<bool> HloVerifier::Run(HloModule* module) {
  TF_RET_CHECK(!module->name().empty());

  if (module->entry_computation()->IsFusionComputation()) {
    return InvalidArgument(
        "Module entry computation cannot be a fusion computation");
  }

  // Parent pointers and operand/computation containment must hold before any
  // per-instruction verification walks the graph.
  TF_RETURN_IF_ERROR(VerifyHloStructure(module));
  TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));

  // The backend supplies its own ShapeVerifier (layout/precision rules are
  // target-specific).
  std::unique_ptr<ShapeVerifier> shape_verifier =
      target_metadata_->GetVerifier();
  for (auto* computation : module->computations()) {
    TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));

    InstructionVerifier instruction_verifier(
        instruction_can_change_layout_func_);
    TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
  }

  TF_RETURN_IF_ERROR(shape_verifier->VerifyEntryComputationLayout(*module));
  TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));

  // If the module has a schedule, it must be valid.
  if (module->has_schedule()) {
    TF_RETURN_IF_ERROR(module->schedule().Verify());
  }

  // Input/output aliasing needs the target's shape-size function to validate
  // buffer sizes.
  TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(
      *module, [this](const Shape& shape) {
        return target_metadata_->ShapeSize(shape);
      }));

  TF_RETURN_IF_ERROR(module->dynamic_parameter_binding().Verify(*module));

  // Verification never mutates the module.
  return false;
}
1506
1507 } // namespace xla
1508