1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "absl/strings/str_replace.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/hlo_parser.h"
27 #include "tensorflow/compiler/xla/service/layout_assignment.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/xla.pb.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33
34 namespace xla {
35 namespace {
36
37 using ::testing::HasSubstr;
38
CreateUnverifiedModule()39 std::unique_ptr<HloModule> CreateUnverifiedModule() {
40 return std::make_unique<HloModule>("module", HloModuleConfig());
41 }
42
43 // This class cannot be converted to use HloTestBase. It explicitly
44 // uses HloTestBase to create and test malformed HLOs.
45 class HloVerifierTest : public HloTestBase {
46 public:
HloVerifierTest()47 HloVerifierTest()
48 : HloTestBase(/*verifier_layout_sensitive=*/false,
49 /*allow_mixed_precision_in_hlo_verifier=*/false) {}
50 };
51
52 class HloVerifierTestAllowMixedPrecision : public HloTestBase {
53 public:
HloVerifierTestAllowMixedPrecision()54 HloVerifierTestAllowMixedPrecision()
55 : HloTestBase(/*verifier_layout_sensitive=*/false,
56 /*allow_mixed_precision_in_hlo_verifier=*/true) {}
57 };
58
59 class HloVerifierTestLayoutSensitive : public HloTestBase {
60 public:
HloVerifierTestLayoutSensitive()61 HloVerifierTestLayoutSensitive()
62 : HloTestBase(/*verifier_layout_sensitive=*/true,
63 /*allow_mixed_precision_in_hlo_verifier=*/false,
64 LayoutAssignment::InstructionCanChangeLayout) {}
65 };
66
67 class HloVerifierTestLayoutFusion : public HloTestBase {
68 public:
HloVerifierTestLayoutFusion()69 HloVerifierTestLayoutFusion()
70 : HloTestBase(/*verifier_layout_sensitive=*/true,
71 /*allow_mixed_precision_in_hlo_verifier=*/false) {}
72 };
73
// An instruction whose parent pointer has been cleared must be reported.
TEST_F(HloVerifierTest, NullInstructionParent) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  // Entry computation: negate(param).
  HloComputation::Builder b(TestName());
  HloInstruction* p = b.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* neg = b.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, p));

  auto m = CreateUnverifiedModule();
  m->AddEntryComputation(b.Build());

  // The intact module verifies cleanly.
  TF_ASSERT_OK(verifier().Run(m.get()).status());

  // Detach the negate from its computation.
  neg->set_parent(nullptr);

  const auto result = verifier().Run(m.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("has a null parent pointer"));
}
92
// A computation whose parent (module) pointer has been cleared must be
// reported.
TEST_F(HloVerifierTest, NullComputationParent) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  HloComputation::Builder b(TestName());
  HloInstruction* p = b.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  b.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, p));

  auto m = CreateUnverifiedModule();
  HloComputation* entry = m->AddEntryComputation(b.Build());

  // The intact module verifies cleanly.
  TF_ASSERT_OK(verifier().Run(m.get()).status());

  // Detach the entry computation from the module.
  entry->set_parent(nullptr);

  const auto result = verifier().Run(m.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("has a null parent pointer"));
}
111
// An instruction must not take an operand that lives in another computation.
TEST_F(HloVerifierTest, DifferentOperandParents) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  auto module = CreateUnverifiedModule();

  // Entry computation: negate(param).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* negate = entry_builder.AddInstruction(HloInstruction::CreateUnary(
      f32_scalar, HloOpcode::kNegate, entry_param));
  module->AddEntryComputation(entry_builder.Build());

  // A separate embedded computation with a parameter of its own.
  HloComputation::Builder embedded_builder(TestName());
  HloInstruction* embedded_param = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  module->AddEmbeddedComputation(embedded_builder.Build());

  TF_ASSERT_OK(verifier().Run(module.get()).status());

  // Point the entry negate at the embedded computation's parameter.
  TF_ASSERT_OK(negate->ReplaceOperandWith(0, embedded_param));

  const auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("is in a different computation"));
}
135
// Running the verifier repeatedly on the same bad module must fail every
// time. The verifier must not carry DFS-visitor state over between runs.
TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
  const Shape f32_1 = ShapeUtil::MakeShape(F32, {1});
  const Shape f32_2 = ShapeUtil::MakeShape(F32, {2});

  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_1, "param"));

  // add(f32[1], f32[1]) is declared, incorrectly, to produce f32[2].
  HloInstruction* bad_add = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_2, HloOpcode::kAdd, param, param));

  // The mis-shaped instruction must not be the root; that configuration is
  // what triggered the original bug.
  builder.AddInstruction(HloInstruction::CreateBinary(
      f32_2, HloOpcode::kMultiply, bad_add, bad_add));

  auto module = CreateUnverifiedModule();
  module->AddEntryComputation(builder.Build());

  for (int run = 0; run < 2; ++run) {
    EXPECT_FALSE(verifier().Run(module.get()).status().ok()) << "run " << run;
  }
}
161
// The call passes a (f32[4], s32[]) tuple, but the callee's parameter is
// declared as (s32[], f32[4]).
TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) {
  constexpr char kHlo[] = R"(
HloModule Module

callme {
ROOT param = (s32[], f32[4]) parameter(0)
}

ENTRY entry {
p0 = (f32[4], s32[]) parameter(0)
ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("shape does not match parameter"));
}
183
// A call from the main-thread entry into a computation tagged with another
// execution thread must be rejected.
TEST_F(HloVerifierTest, CheckCallThreadMismatch) {
  constexpr char kHlo[] = R"(
HloModule Module

callme {
ROOT param = (s32[], f32[4]) parameter(0)
}, execution_thread="parallel_thread"

ENTRY entry {
p0 = (s32[], f32[4]) parameter(0)
ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("expects parent computation thread name same as called "
                        "computation's thread name"));
}
205
// Both branches take (s32[], f32[4]), but the conditional feeds them a
// (f32[4], s32[]) tuple.
TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
  constexpr char kHlo[] = R"(
HloModule Module

true_branch {
tparam = (s32[], f32[4]) parameter(0)
ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1
}

false_branch {
fparam = (s32[], f32[4]) parameter(0)
ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1
}

ENTRY entry {
p0 = (f32[4], s32[]) parameter(0)
constant = pred[] constant(true)
ROOT conditional = f32[4] conditional(constant, p0, p0),
true_computation=true_branch, false_computation=false_branch
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("shape does not match parameter"));
}
235
// An indexed conditional requires its first operand (the branch index) to be
// an S32 scalar. Starting from a well-formed module, the index's shape is
// corrupted in two different ways and each must be rejected.
TEST_F(HloVerifierTest, CheckConditionalBranchIndexOperandShape) {
  const char* const hlo_string = R"(
HloModule Module

branch0 {
tparam = f32[4] parameter(0)
ROOT tgte1 = f32[4] ceil(tparam)
}

branch1 {
fparam = f32[4] parameter(0)
ROOT fgte1 = f32[4] floor(fparam)
}

branch2 {
sparam = f32[4] parameter(0)
ROOT sgte1 = f32[4] ceil(sparam)
}

ENTRY entry {
p0 = f32[4] parameter(0)
b0 = s32[] parameter(1)
ROOT conditional = f32[4] conditional(b0, p0, p0, p0),
branch_computations={branch0, branch1, branch2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  // Assert the unmodified module is valid. Previously this result was
  // computed and discarded, so the failures below could not be attributed to
  // the mutations.
  TF_ASSERT_OK(verifier().Run(module.get()).status());

  HloInstruction* condition = FindInstruction(module.get(), "b0");
  // Guard against a null dereference if the instruction name ever changes.
  ASSERT_NE(condition, nullptr);

  // A scalar of the wrong element type is rejected.
  *condition->mutable_shape() = ShapeUtil::MakeShape(F32, {});
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr(
          "first operand of indexed conditional must be a scalar of S32"));

  // A non-scalar index of the right element type is rejected as well.
  *condition->mutable_shape() = ShapeUtil::MakeShape(S32, {4});
  status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("first operand of conditional must be a scalar"));
}
281
// A conditional branch tagged with a different execution thread than the
// conditional's own computation must be rejected.
TEST_F(HloVerifierTest, CheckConditionalBranchThread) {
  const char* const hlo_string = R"(
HloModule Module

branch0 {
tparam = f32[4] parameter(0)
ROOT tgte1 = f32[4] ceil(tparam)
}

branch1 {
fparam = f32[4] parameter(0)
ROOT fgte1 = f32[4] floor(fparam)
}, execution_thread="parallel_thread"

branch2 {
sparam = f32[4] parameter(0)
ROOT sgte1 = f32[4] ceil(sparam)
}

ENTRY entry {
p0 = f32[4] parameter(0)
b0 = s32[] parameter(1)
ROOT conditional = f32[4] conditional(b0, p0, p0, p0),
branch_computations={branch0, branch1, branch2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));
  auto status = verifier().Run(module.get()).status();
  // Assert the failure explicitly before inspecting the message, matching
  // CheckCallThreadMismatch.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("expects parent computation thread name same as called "
                        "computation's thread name"));
}
315
// A branch that only *launches* async work on another thread stays on the
// main thread itself, so the module is valid.
TEST_F(HloVerifierTest, CheckConditionalBranchContainsAsyncThread) {
  constexpr char kHlo[] = R"(
HloModule Module

branch0 {
tparam = f32[4] parameter(0)
ROOT tgte1 = f32[4] ceil(tparam)
}

branch1 {
fparam = f32[4] parameter(0)
%async-start = ((f32[4]), f32[4], s32[]) custom-call-start(f32[4] fparam), async_execution_thread="parallel_thread", custom_call_target="foo"
ROOT %async-done = f32[4] custom-call-done(((f32[4]), f32[4], s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="foo"
}

branch2 {
sparam = f32[4] parameter(0)
ROOT sgte1 = f32[4] ceil(sparam)
}

ENTRY entry {
p0 = f32[4] parameter(0)
b0 = s32[] parameter(1)
ROOT conditional = f32[4] conditional(b0, p0, p0, p0),
branch_computations={branch0, branch1, branch2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  const auto verify_status = verifier().Run(module.get()).status();
  TF_ASSERT_OK(verify_status);
}
347
// rng bounds must be scalars. Here the second bound is an f16[2] array.
TEST_F(HloVerifierTest, RngOpnd0NotScalar) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY RngOpnd0NotScalar {
constant.0 = f32[] constant(0)
constant.1 = f16[2] constant({1, 3})
ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[2] constant.1),
distribution=rng_uniform
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected scalar type"));
}
366
// rng bounds with differing element types (f32 vs f16) are rejected.
TEST_F(HloVerifierTest, RngOperandElementTypesDoNotMatch) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY RngOperandElementTypesNotMatch {
constant.0 = f32[] constant(0)
constant.1 = f16[] constant(1)
ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected compatible element types"));
}
386
// f32 bounds with an f16 result count as mixed precision, which this fixture
// forbids.
TEST_F(HloVerifierTest, RngMixedPrecisionNotAllowed) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY RngResultElementTypeNotMatch {
constant.0 = f32[] constant(0)
constant.1 = f32[] constant(1)
ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected compatible element types"));
}
406
// Same HLO as RngMixedPrecisionNotAllowed (f32 bounds, f16 result). Under the
// mixed-precision-tolerant fixture it must verify.
TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY RngResultElementTypeNotMatch {
constant.0 = f32[] constant(0)
constant.1 = f32[] constant(1)
ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // TF_ASSERT_OK prints the verifier's error on failure, whereas
  // ASSERT_TRUE(status.ok()) printed only "false".
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
424
// rng_normal over s32 values is reported as an unsupported element type.
TEST_F(HloVerifierTest, RngElementTypeNotSupported) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY RngElementTypeNotSupported {
constant.0 = s32[] constant(0)
constant.1 = s32[] constant(1)
ROOT rng.0 = s32[10]{0} rng(s32[] constant.0, s32[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Element type not supported"));
}
443
// A pad with negative interior padding must be rejected. The test builds the
// HLO programmatically because the text parser does not accept negative
// interior padding.
TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
  const Shape f32_100 = ShapeUtil::MakeShape(F32, {100});

  HloComputation::Builder builder(TestName());
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_100, "param"));

  PaddingConfig negative_interior;
  negative_interior.add_dimensions()->set_interior_padding(-1);

  HloInstruction* pad_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
  builder.AddInstruction(HloInstruction::CreatePad(f32_100, operand, pad_value,
                                                   negative_interior));

  auto module = CreateUnverifiedModule();
  module->AddEntryComputation(builder.Build());

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
467
// A pad whose interior padding (dilation between elements) is negative must
// be rejected. The text parser does not accept negative interior padding, so
// the HLO is built programmatically.
TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {100}), "param"));
  PaddingConfig padding_config;
  padding_config.add_dimensions()->set_interior_padding(-1);
  // LiteralUtil::Zero already yields a temporary Literal, so the former
  // .Clone() only made a redundant copy.
  builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {100}), param,
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
      padding_config));

  auto module = CreateUnverifiedModule();
  module->AddEntryComputation(builder.Build());

  auto status = verifier().Run(module.get()).status();
  // Assert the failure explicitly so an unexpected OK result is reported as
  // such, not as an empty-message substring mismatch.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
489
// A dot with an f32 lhs and a bf16 rhs is expected to verify.
// NOTE(review): this runs under HloVerifierTest, whose verifier is
// constructed with allow_mixed_precision_in_hlo_verifier=false. The test
// therefore relies on dot being accepted regardless; confirm that is intended.
TEST_F(HloVerifierTest, DotMixedPrecisionAllowed) {
  constexpr char kDotHloString[] = R"(
HloModule module
ENTRY entry_computation {
a = f32[2,10] parameter(0)
b = bf16[10,2] parameter(1)
ROOT dot = f32[2,2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kDotHloString));

  TF_EXPECT_OK(verifier().Run(module.get()).status());
}
504
// HLO text for a module whose root is an f16 convolution with a 3x3,
// stride-2 window. The dilation tests below parse it and then corrupt the
// convolution's window.
const char* const kConvHloString = R"(
HloModule module
ENTRY entry_computation {
param0 = f16[128,128,56,56] parameter(0)
param1 = f16[3,3,128,128] parameter(1)
zero_f16 = f16[] constant(0)
ROOT conv = f16[128,128,28,28] convolution(param0, param1),
window={size=3x3 stride=2x2}, dim_labels=bf01_01io->bf01
})";
515
// A convolution whose window dilation is negative must be rejected.
TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kConvHloString));
  HloInstruction* conv = module->entry_computation()->root_instruction();

  // Corrupt the first window dimension.
  Window bad_window = conv->window();
  bad_window.mutable_dimensions(0)->set_window_dilation(-1);
  conv->set_window(bad_window);

  const auto verify_status = verifier().Run(module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("non-positive window dilation factor"));
}
527
// A convolution whose base (input) dilation is negative must be rejected.
TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kConvHloString));
  HloInstruction* conv = module->entry_computation()->root_instruction();

  // Corrupt the first window dimension.
  Window bad_window = conv->window();
  bad_window.mutable_dimensions(0)->set_base_dilation(-1);
  conv->set_window(bad_window);

  const auto verify_status = verifier().Run(module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("non-positive base area dilation factor"));
}
539
// An add whose operands use different layouts ({1,0} and {0,1}). The layout
// insensitive fixture accepts it; the layout-sensitive one rejects it.
const char* const kAddWithLayoutChangeHlo = R"(
HloModule AddWithLayoutChange
ENTRY AddWithLayoutChange {
par0 = f32[3,4]{1,0} parameter(0)
par1 = f32[3,4]{0,1} parameter(1)
ROOT add0 = f32[3,4]{1,0} add(par0,par1)
}
)";
548
// Without layout sensitivity, an add mixing operand layouts is accepted.
TEST_F(HloVerifierTest, AddWithLayoutChange) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnVerifiedModule(kAddWithLayoutChangeHlo));
  // TF_ASSERT_OK reports the verifier's message on failure, whereas
  // ASSERT_TRUE(status.ok()) did not.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
555
// With xla_allow_scalar_index_dynamic_ops enabled, a dynamic-slice that takes
// one scalar start index per dimension verifies.
TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) {
  const char* const kScalarIndexDynamicSlice = R"(
HloModule DynamicSlice_module

ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
%original_parameter = s32[2,2,258] parameter(0)
%constant = s32[] constant(0)
%start_index = s32[] parameter(1)
ROOT %dynamic-slice = s32[2,2,258] dynamic-slice(s32[2,2,258] %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
}
)";

  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
                                           kScalarIndexDynamicSlice, config));
  // TF_ASSERT_OK surfaces the verifier's message if this ever fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
578
// With xla_allow_scalar_index_dynamic_ops enabled, a dynamic-update-slice
// that takes one scalar start index per dimension verifies.
TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) {
  // Renamed from kScalarIndexDynamicSlice: this HLO is an update-slice.
  const char* const kScalarIndexDynamicUpdateSlice = R"(
HloModule DynamicUpdateSlice_module

ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
%input = s32[1,1,25,1]{3,2,1,0} parameter(0)
%update = s32[1,1,2,1]{3,2,1,0} parameter(1)
%start_index.0 = s32[] parameter(2)
%start_index.1 = s32[] parameter(3)
%start_index.2 = s32[] parameter(4)
%start_index.3 = s32[] parameter(5)
ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3)
}
)";

  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      ParseAndReturnVerifiedModule(kScalarIndexDynamicUpdateSlice, config));
  // TF_ASSERT_OK surfaces the verifier's message if this ever fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
604
// Even when mixed precision is allowed, a dynamic-update-slice's result must
// keep the type of the operand being updated (f32 here, not bf16).
TEST_F(HloVerifierTestAllowMixedPrecision, DynamicUpdateSliceMixedPrecision) {
  constexpr char kHlo[] = R"(
HloModule kDynamicUpdateSliceMixedPrecision

ENTRY %entry (parameter.0: f32[32,511,2048], parameter.1: bf16[32,511,512], parameter.2: s32[], parameter.3: s32[], parameter.4: s32[]) -> bf16[32,511,2048] {
%parameter.0 = f32[32,511,2048] parameter(0)
%parameter.1 = bf16[32,511,512] parameter(1)
%parameter.2 = s32[] parameter(2)
%parameter.3 = s32[] parameter(3)
%parameter.4 = s32[] parameter(4)
ROOT %dus = bf16[32,511,2048] dynamic-update-slice(f32[32,511,2048] %parameter.0, bf16[32,511,512] %parameter.1, s32[] %parameter.2, s32[] %parameter.3, s32[] %parameter.4)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "f32[32,511,2048], actual shape is bf16[32,511,2048]"));
}
626
// With layout sensitivity, an add must not change layouts between operands
// and result.
TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnUnverifiedModule(kAddWithLayoutChangeHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
635
// A dynamic-slice from a {0,1} operand into a {1,0} result changes layouts.
TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
  constexpr char kSliceWithLayoutChangeHlo[] = R"(
HloModule SliceWithLayoutChange
ENTRY SliceWithLayoutChange {
par0 = f32[4,5]{0,1} parameter(0)
par1 = s32[] parameter(1)
par2 = s32[] parameter(2)
ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1, par2),
dynamic_slice_sizes={3,4}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnUnverifiedModule(kSliceWithLayoutChangeHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
654
// Concatenating a {0,1} operand into a {1,0} result changes layouts.
TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) {
  constexpr char kConcatWithLayoutChangeHlo[] = R"(
HloModule ConcatWithLayoutChange
ENTRY ConcatWithLayoutChange {
par0 = f32[3,5]{0,1} parameter(0)
par1 = f32[3,3]{1,0} parameter(1)
ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1),
dimensions={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnUnverifiedModule(kConcatWithLayoutChangeHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
672
// A bitcast from f32[2] (8 bytes) to f32[3] (12 bytes) must be rejected.
TEST_F(HloVerifierTestLayoutSensitive, BitcastNeedsSameNumberOfElements) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY BitcastNeedsToBeNoOp {
constant.0 = f32[2] constant({0.0, 0.0})
ROOT bitcast = f32[3] bitcast(constant.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Bitcast cannot have different shape sizes of output "
                        "(12) and operand (8)"));
}
691
// A select choosing between f32 and bf16 arrays is mixed precision, which
// this fixture forbids.
TEST_F(HloVerifierTest, SelectMixedPrecisionNotAllowed) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY SelectMixedPrecisionNotAllowed {
p0 = pred[32] parameter(0)
p1 = f32[32] parameter(1)
p2 = bf16[32] parameter(2)
ROOT select = f32[32] select(p0, p1, p2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Seen floating point types of different precisions"));
}
711
// Same select as SelectMixedPrecisionNotAllowed. Under the
// mixed-precision-tolerant fixture it must verify.
TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY SelectMixedPrecisionAllowed {
p0 = pred[32] parameter(0)
p1 = f32[32] parameter(1)
p2 = bf16[32] parameter(2)
ROOT select = f32[32] select(p0, p1, p2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // TF_ASSERT_OK surfaces the verifier's message if this ever fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
729
// select only operates on arrays. Tuple-shaped operands are rejected.
TEST_F(HloVerifierTest, SelectTupleNotAllowed) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY SelectWithTuple {
p0 = (f32[], f32[]) parameter(0)
p1 = (f32[], f32[]) parameter(1)
p2 = pred[] parameter(2)
ROOT select = (f32[], f32[]) select(p2, p0, p1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected array argument for select"));
}
749
// A well-formed copy-start/copy-done pair that moves data between memory
// spaces S(1) and S(2) verifies under layout sensitivity.
TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDone) {
  const char* const hlo_string = R"(
HloModule Module

ENTRY CopyStartAndCopyDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
copy-start = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // TF_ASSERT_OK surfaces the verifier's message if this ever fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
766
// copy-start's first tuple element uses layout {0,1}, which disagrees with
// the {1,0} layout of its operand and of the copy-done result.
TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDoneWrongLayout) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY CopyStartAndCopyDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
copy-start = (f32[2,3]{0,1:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to"));
}
785
// copy-start must produce a (dest, src, u32[]) tuple, not a plain array.
TEST_F(HloVerifierTest, CopyStartAndCopyDoneWrongType) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY CopyStartAndCopyDone {
p0 = f32[2,3] parameter(0)
copy-start = f32[2,3] copy-start(p0)
ROOT copy-done = f32[2,3] copy-done(copy-start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "(f32[2,3], f32[2,3], u32[])"));
}
805
// A copy-start may feed exactly one copy-done. Two consumers are rejected.
TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY CopyStartAndCopyDone {
p0 = f32[2,3] parameter(0)
copy-start = (f32[2,3], f32[2,3], u32[]) copy-start(p0)
copy-done.1 = f32[2,3] copy-done(copy-start)
copy-done.2 = f32[2,3] copy-done(copy-start)
ROOT tuple = (f32[2,3], f32[2,3]) tuple(copy-done.1, copy-done.2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("copy-start instruction requires one consumer, found 2"));
}
827
// copy-done must consume a copy-start. A tuple of the matching shape is not
// acceptable.
TEST_F(HloVerifierTest, CopyDoneNoCopyStart) {
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY CopyStartAndCopyDone {
p0 = f32[2,3] parameter(0)
p1 = u32[] parameter(1)
tuple = (f32[2,3], f32[2,3], u32[]) tuple(p0, p0, p1)
ROOT copy-done = f32[2,3] copy-done(tuple)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("The operand of a copy-done instruction needs to be "
                        "copy-start, found tuple"));
}
848
// A well-formed custom-call-start / custom-call-done pair, whose layouts carry
// S(1)/S(2) annotations, passes layout-sensitive verification.
TEST_F(HloVerifierTestLayoutSensitive, AsyncStartAndAsyncDone) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), custom_call_target="foo"
    ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-start), custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
865
// A start -> update -> update -> done chain of custom-call async ops passes
// layout-sensitive verification.
TEST_F(HloVerifierTestLayoutSensitive, AsyncStartAndAsyncUpdateAndAsyncDone) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncUpdateAndAsyncDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), custom_call_target="foo"
    async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), custom_call_target="foo"
    async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), custom_call_target="foo"
    ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
884
// Same start/update/done chain as above, with every op carrying the same
// async_execution_thread; verification must still succeed.
TEST_F(HloVerifierTestLayoutSensitive,
       AsyncStartAndAsyncUpdateAndAsyncDoneWithThreadName) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncUpdateAndAsyncDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), async_execution_thread="parallel_thread", custom_call_target="foo"
    async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), async_execution_thread="parallel_thread", custom_call_target="foo"
    async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), async_execution_thread="parallel_thread", custom_call_target="foo"
    ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), async_execution_thread="parallel_thread", custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
904
// The async shape's index {1} (f32[3,2]) disagrees with the wrapped custom-call
// result (f32[2,3]); async-done must report the mismatch.
TEST_F(HloVerifierTest, AsyncStartAndAsyncDoneWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[3,2], u32[]) custom-call-start(p0), custom_call_target="foo"
    ROOT async-done = f32[2,3] custom-call-done(async-start), custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("async-done expects the async shape at index {1} to "
                        "match the async computation root shape"));
}
924
// async-start and async-done name different async_execution_thread values; the
// verifier must reject the pair and mention both thread names.
TEST_F(HloVerifierTest, AsyncStartAndAsyncDoneWrongThreadName) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), async_execution_thread="parallel_thread", custom_call_target="foo"
    ROOT async-done = f32[2,3] custom-call-done(async-start), async_execution_thread="main_thread", custom_call_target="bar"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("thread name (main_thread vs parallel_thread)."));
}
943
// async-done wraps a custom call with a different target ("bar") than its
// async-start ("foo"); the wrapped computations must be identical.
TEST_F(HloVerifierTest, AsyncStartAndAsyncDoneWrongAttr) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo"
    ROOT async-done = f32[2,3] custom-call-done(async-start), custom_call_target="bar"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("async-done expects its wrapped async computation to "
                        "be identical to its operand's"));
}
963
// One async-start consumed by two async-done instructions must be rejected.
TEST_F(HloVerifierTest, AsyncStartMultipleAsyncDone) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo"
    async-done.1 = f32[2,3] custom-call-done(async-start), custom_call_target="foo"
    async-done.2 = f32[2,3] custom-call-done(async-start), custom_call_target="foo"
    ROOT tuple = (f32[2,3], f32[2,3]) tuple(async-done.1, async-done.2)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr("async-start instruction requires one consumer, found 2"));
}
985
// An async-start with no consumer at all (it is the root) must be rejected.
TEST_F(HloVerifierTest, AsyncStartNoAsyncDone) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    ROOT async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr("async-start instruction requires one consumer, found 0"));
}
1004
// An async-update that ends the chain (root, no consumer) must be rejected.
TEST_F(HloVerifierTest, AsyncStartAndAsyncUpdateNoAsyncDone) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo"
    ROOT async-update = ((f32[2,3]), f32[2,3], u32[]) custom-call-update(async-start), custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr("async-update instruction requires one consumer, found 0"));
}
1024
// async-done whose operand is an ordinary tuple (neither async-start nor
// async-update) must be rejected.
TEST_F(HloVerifierTest, AsyncDoneNoAsyncStart) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    p1 = u32[] parameter(1)
    tuple = ((f32[2,3]), f32[2,3], u32[]) tuple(p0, p0, p1)
    ROOT async-done = f32[2,3] custom-call-done(tuple), custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("The operand of a async-done instruction needs to be "
                        "async-start or async-update, found tuple"));
}
1045
// async-update whose operand is an ordinary tuple (no async-start anywhere in
// the chain) must be rejected.
TEST_F(HloVerifierTest, AsyncUpdateAndAsyncDoneNoAsyncStart) {
  // async-done consumes async-update, so the chain reads
  // tuple -> async-update -> async-done and the only malformed link is the
  // update's tuple operand. (Feeding `tuple` straight into async-done would
  // also leave async-update without a consumer, mixing defects in one test.)
  const char* const hlo_string = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    p1 = u32[] parameter(1)
    tuple = ((f32[2,3]), f32[2,3], u32[]) tuple(p0, p0, p1)
    async-update = ((f32[2,3]), f32[2,3], u32[]) custom-call-update(tuple), custom_call_target="foo"
    ROOT async-done = f32[2,3] custom-call-done(async-update), custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("The operand of a async-update instruction needs to be "
                        "async-start or async-update, found tuple"));
}
1067
// The async shape's index {0} (f32[3,2]) disagrees with the async computation's
// parameter (f32[2,3]); async-start must report the mismatch.
TEST_F(HloVerifierTest, AsyncOpComputationParamWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  async_computation {
    p = f32[2,3] parameter(0)
    ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
  }

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[3,2]), f32[3,2], u32[]) async-start(p0), calls=async_computation
    ROOT async-done = f32[3,2] async-done(async-start), calls=async_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("async-start expects the async shape at index {0} to "
                        "match async computation parameter shape"));
}
1092
// The async shape's index {1} (f32[2,3]) disagrees with the async computation's
// root (f32[3,2]); async-start must report the mismatch.
TEST_F(HloVerifierTest, AsyncOpComputationRootWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  async_computation {
    p = f32[2,3] parameter(0)
    ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
  }

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[2,3], u32[]) async-start(p0), calls=async_computation
    ROOT async-done = f32[3,2] async-done(async-start), calls=async_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("async-start expects the async shape at index {1} to "
                        "match the async computation root shape"));
}
1117
// An async-start whose shape is a one-element tuple must be rejected; the
// async shape needs at least two elements.
TEST_F(HloVerifierTest, AsyncOpTupleWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  async_computation {
    p = f32[2,3] parameter(0)
    ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
  }

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3])) async-start(p0), calls=async_computation
    ROOT async-done = f32[3,2] async-done(async-start), calls=async_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("async-start expects the async shape to be a tuple of "
                        "at least two elements"));
}
1142
// async-start's operand p0 (f32[3,2]) disagrees with the async shape's index
// {0} (f32[2,3]); the verifier must report the operand mismatch.
TEST_F(HloVerifierTest, AsyncStartOperandWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  async_computation {
    p = f32[2,3] parameter(0)
    ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
  }

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[3,2] parameter(0)
    async-start = ((f32[2,3]), f32[3,2], u32[]) async-start(p0), calls=async_computation
    ROOT async-done = f32[3,2] async-done(async-start), calls=async_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("async-start expects the shape of operand 0 to match "
                        "the async shape at index {0}"));
}
1167
// async-done's own shape (f32[2,3]) disagrees with the async shape's index {1}
// (f32[3,2]); the verifier must report the output mismatch.
TEST_F(HloVerifierTest, AsyncDoneOutputWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  async_computation {
    p = f32[2,3] parameter(0)
    ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
  }

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[3,2], u32[]) async-start(p0), calls=async_computation
    ROOT async-done = f32[2,3] async-done(async-start), calls=async_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("async-done expects the shape of output to match the "
                        "async shape at index {1}"));
}
1192
// async-update's shape differs from its operand's (index {0} is f32[3,2] vs
// f32[2,3]); the verifier must reject the update.
TEST_F(HloVerifierTest, AsyncUpdateWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  async_computation {
    p = f32[2,3] parameter(0)
    ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
  }

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[3,2], u32[]) async-start(p0), calls=async_computation
    async-update = ((f32[3,2]), f32[3,2], u32[]) async-update(async-start), calls=async_computation
    ROOT async-done = f32[3,2] async-done(async-update), calls=async_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr(
          "async-update expects the shape of operand and output to match"));
}
1219
// async-done declares async_group_id=1 while the rest of its chain uses 0; the
// verifier must reject the mismatch.
TEST_F(HloVerifierTestLayoutSensitive, AsyncDoneWrongGroupId) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncUpdateAndAsyncDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), async_group_id=0, custom_call_target="foo"
    async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), async_group_id=0, custom_call_target="foo"
    async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), async_group_id=0, custom_call_target="foo"
    ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), async_group_id=1, custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("async-done expects its operand to have the same group "
                        "id (1 vs 0)."));
}
1241
// async-update.1 omits async_group_id while its operand uses 0; the error reports
// the missing id as "none".
TEST_F(HloVerifierTestLayoutSensitive, AsyncUpdateWrongGroupId) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncUpdateAndAsyncDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), async_group_id=0, custom_call_target="foo"
    async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), custom_call_target="foo"
    async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), async_group_id=0, custom_call_target="foo"
    ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), async_group_id=0, custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("async-update expects its operand to have the same "
                        "group id (none vs 0)."));
}
1263
// An iota whose result is a tuple (here the empty tuple) must be rejected.
TEST_F(HloVerifierTest, IotaNonArrayResult) {
  constexpr char kHlo[] = R"(
  HloModule IotaTupleResult

  ENTRY kernelEntry {
    ROOT iota = () iota(), iota_dimension=24
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("does not support non-array result"));
}
1281
// An iota with a negative iota_dimension must be rejected.
TEST_F(HloVerifierTest, IotaNegativeDimension) {
  // The module is named after this test; the previous name, IotaTupleResult,
  // had been copy-pasted from IotaNonArrayResult.
  const char* const hlo_string = R"(
  HloModule IotaNegativeDimension

  ENTRY kernelEntry {
    ROOT iota = s32[128,1001]{1,0} iota(), iota_dimension=-1
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("negative"));
}
1298
// An iota producing a PRED array must be rejected.
TEST_F(HloVerifierTest, IotaPredResultNotAllowed) {
  constexpr char kHlo[] = R"(
  HloModule IotaPredResult

  ENTRY kernelEntry {
    ROOT iota = pred[128] iota(), iota_dimension=0
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("got PRED"));
}
1315
// HLO in which `map` feeds an f64 operand into a computation whose parameter
// is f32. The default-verifier and mixed-precision variants of
// MapOperandComputationMismatch both use this text.
constexpr char kMapOperandComputationMismatchHlo[] = R"(
HloModule MapOperandComputationMismatch

Computation {
  param0 = f32[] parameter(0)
  constant = f32[] constant(1)
  ROOT add = f32[] add(param0, constant)
}

ENTRY kernelEntry {
  param = f64[] parameter(0)
  ROOT map = f32[] map(param), dimensions={}, to_apply=Computation
})";
1329
// With the default verifier, the f64 operand / f32 parameter mismatch in the
// map is an error.
TEST_F(HloVerifierTest, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnUnverifiedModule(kMapOperandComputationMismatchHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr(
          "Shape mismatch between to_apply computation parameter and operand"));
}
1340
// With mixed precision allowed, the same map module verifies cleanly.
TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      ParseAndReturnVerifiedModule(kMapOperandComputationMismatchHlo));
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
1347
// HLO in which an f16 reduce uses an f32 reduction computation. The
// default-verifier and mixed-precision variants of
// ReduceOperandComputationMismatch both use this text.
constexpr char kReduceOperandComputationMismatchHlo[] = R"(
HloModule ReduceOperandComputationMismatch
computation {
  x = f32[] parameter(0)
  y = f32[] parameter(1)
  ROOT add = f32[] add(x, y)
}

ENTRY kernelEntry {
  arg0 = f16[64,64,224,224]{3,2,1,0} parameter(0)
  constant = f16[] constant(0)
  reduce = f16[64]{0} reduce(arg0, constant), dimensions={0,2,3}, to_apply=computation
})";
1361
// With the default verifier, the f16 reduce must take its element type from
// the f32 computation, so f32[64] is expected.
TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnUnverifiedModule(kReduceOperandComputationMismatchHlo));

  const Status status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Expected instruction to have shape equal to f32[64]"));
}
1371
// With mixed precision allowed, the same reduce module verifies cleanly.
TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      ParseAndReturnVerifiedModule(kReduceOperandComputationMismatchHlo));
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
1379
ReplicaGroupsStr(std::vector<std::vector<int64_t>> replica_groups)1380 std::string ReplicaGroupsStr(std::vector<std::vector<int64_t>> replica_groups) {
1381 std::vector<std::string> replica_group_strs;
1382 replica_group_strs.reserve(replica_groups.size());
1383 for (const auto& g : replica_groups) {
1384 replica_group_strs.push_back(
1385 absl::StrFormat("{%s}", absl::StrJoin(g, ",")));
1386 }
1387 return absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "));
1388 }
1389
// Returns the total number of replica ids listed across all groups. Repeated
// ids are counted each time they appear. MakeCollectiveCommOpComputation uses
// this as the default replica count.
int64_t ReplicaCount(const std::vector<std::vector<int64_t>>& replica_groups) {
  int64_t replica_count = 0;
  // Bind each group by const reference; `auto group` copied every vector.
  for (const auto& group : replica_groups) {
    replica_count += static_cast<int64_t>(group.size());
  }
  return replica_count;
}
1397
MakeCollectiveCommOpComputation(std::vector<std::vector<int64_t>> replica_groups,std::optional<int64_t> replica_count,std::optional<int64_t> num_partitions,absl::string_view other_attributes,absl::string_view template_str)1398 StatusOr<std::unique_ptr<HloModule>> MakeCollectiveCommOpComputation(
1399 std::vector<std::vector<int64_t>> replica_groups,
1400 std::optional<int64_t> replica_count, std::optional<int64_t> num_partitions,
1401 absl::string_view other_attributes, absl::string_view template_str) {
1402 HloModuleConfig config;
1403 config.set_replica_count(
1404 replica_count.value_or(ReplicaCount(replica_groups)));
1405 config.set_num_partitions(num_partitions.value_or(1));
1406 return ParseAndReturnUnverifiedModule(
1407 absl::StrReplaceAll(
1408 template_str,
1409 {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)},
1410 {"OTHER_ATTRIBUTES", other_attributes.empty()
1411 ? ""
1412 : absl::StrCat(",", other_attributes)}}),
1413 config);
1414 }
1415
MakeAllReduceComputation(std::vector<std::vector<int64_t>> replica_groups,std::optional<int64_t> replica_count=std::nullopt,std::optional<int64_t> num_partitions=std::nullopt,absl::string_view other_attributes="")1416 StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
1417 std::vector<std::vector<int64_t>> replica_groups,
1418 std::optional<int64_t> replica_count = std::nullopt,
1419 std::optional<int64_t> num_partitions = std::nullopt,
1420 absl::string_view other_attributes = "") {
1421 const char* kTemplate = R"(
1422 HloModule test
1423 add {
1424 x = f32[] parameter(0)
1425 y = f32[] parameter(1)
1426 ROOT add = f32[] add(x, y)
1427 }
1428 ENTRY entry {
1429 p = f32[128]{0} parameter(0)
1430 crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS
1431 OTHER_ATTRIBUTES
1432 })";
1433 return MakeCollectiveCommOpComputation(replica_groups, replica_count,
1434 num_partitions, other_attributes,
1435 kTemplate);
1436 }
1437
// An all-reduce with an empty replica-group list verifies cleanly.
TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          MakeAllReduceComputation({}));
  const Status status = verifier().Run(module.get()).status();
  TF_ASSERT_OK(status);
}
1442
// All-reduce replica groups of unequal sizes are accepted.
TEST_F(HloVerifierTest, AllReduce_DifferentGroupSizesOk) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          MakeAllReduceComputation({{0}, {1, 3}, {2}}));
  const Status status = verifier().Run(module.get()).status();
  TF_ASSERT_OK(status);
}
1448
// An all-reduce with an empty replica group must be rejected.
TEST_F(HloVerifierTest, AllReduce_EmptyReplicaGroup) {
  TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllReduceComputation({{0}, {}}));
  // Assert failure first so a verifier that wrongly accepts the module fails
  // here with a clear message, not as a confusing substring mismatch.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("empty replica group"));
}
1454
// Replica 0 appears in two all-reduce groups; the verifier must reject it.
TEST_F(HloVerifierTest, AllReduce_RepeatedReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          MakeAllReduceComputation({{0, 1}, {2, 3}, {4, 0}}));
  // Assert failure before inspecting the message (see AllReduce_EmptyReplicaGroup).
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("Replica 0 is repeated"));
}
1461
// Replica 4 is skipped in the all-reduce groups; the verifier must reject it.
TEST_F(HloVerifierTest, AllReduce_MissingReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          MakeAllReduceComputation({{0, 1}, {2, 3}, {5, 6}}));
  // Assert failure before inspecting the message so an accepted module is
  // reported clearly.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("Replica 4 is not named"));
}
1468
// The module is configured for 8 replicas but the groups name only 2.
TEST_F(HloVerifierTest, AllReduce_NotEnougReplicasInGroupConfig) {
  TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllReduceComputation({{0, 1}}, 8));
  // Assert failure before inspecting the message so an accepted module is
  // reported clearly.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("In kCrossReplica mode, replica groups should contain "
                        "8 replicas, but found 2"));
}
1475
// The module is configured for 2 replicas but the groups name 4.
TEST_F(HloVerifierTest, AllReduce_TooManyReplicasInGroupConfig) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          MakeAllReduceComputation({{0, 1}, {2, 3}}, 2));
  // Assert failure before inspecting the message so an accepted module is
  // reported clearly.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("In kCrossReplica mode, replica groups should contain "
                        "2 replicas, but found 4"));
}
1483
// With channel_id set, the groups are checked in kCrossReplicaAndPartition mode:
// 2 configured replicas vs 4 named ids must be rejected.
TEST_F(HloVerifierTest, AllReduce_CrossReplicaAndPartition_Invalid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 2, 1, "channel_id=1"));
  // Assert failure before inspecting the message so an accepted module is
  // reported clearly.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr(
          "In kCrossReplicaAndPartition mode, replica groups should contain "
          "2 replicas, but found 4"));
}
1494
// With channel_id set and 4 configured replicas, groups naming 4 ids verify.
TEST_F(HloVerifierTest, AllReduce_CrossReplicaAndPartition_Valid) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 4, 1, "channel_id=1"));
  const Status status = verifier().Run(module.get()).status();
  TF_ASSERT_OK(status);
}
1501
// With use_global_device_ids=true, the groups are checked in kFlattenedID mode:
// 1 replica x 2 partitions gives 2 ids, but the groups name 4.
TEST_F(HloVerifierTest, AllReduce_FlattenedID_Invalid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 1, 2,
                               "channel_id=1, use_global_device_ids=true"));
  // Assert failure before inspecting the message so an accepted module is
  // reported clearly.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("In kFlattenedID mode, replica groups should contain "
                        "2 flattened IDs, but found 4"));
}
1511
// 2 replicas x 2 partitions match the 4 flattened ids named in the groups.
TEST_F(HloVerifierTest, AllReduce_FlattenedID_Valid) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 2, 2,
                               "channel_id=1, use_global_device_ids=true"));
  const Status status = verifier().Run(module.get()).status();
  TF_ASSERT_OK(status);
}
1519
// An all-reduce-start with an array shape, consumed by one all-reduce-done,
// verifies cleanly.
TEST_F(HloVerifierTest, AllReduceStartAndDone) {
  constexpr char kHlo[] = R"(
  HloModule test
  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    start = f32[2,3] all-reduce-start(p0), to_apply=add
    ROOT done = f32[2,3] all-reduce-done(start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHlo));

  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
1540
// An all-reduce-start declared with a tuple shape must be rejected; the
// verifier expects the operand's array shape, f32[2,3].
TEST_F(HloVerifierTest, AllReduceStartAndDoneWrongType) {
  const char* const kModuleStr = R"(
  HloModule test
  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    start = (f32[2,3], f32[2,3]) all-reduce-start(p0), to_apply=add
    ROOT done = f32[2,3] all-reduce-done(start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  auto status = verifier().Run(module.get()).status();
  // Assert failure first, matching the other negative tests; otherwise an OK
  // status shows up only as a confusing substring mismatch.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "f32[2,3]"));
}
1563
// One all-reduce-start consumed by two all-reduce-done instructions must be
// rejected.
TEST_F(HloVerifierTest, AllReduceStartAndMultipleDone) {
  // `start` uses the array shape accepted by AllReduceStartAndDone. The tuple
  // shape used before is itself invalid (see AllReduceStartAndDoneWrongType),
  // so the test mixed two defects; now the second consumer is the only one.
  const char* const kModuleStr = R"(
  HloModule test
  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    start = f32[2,3] all-reduce-start(p0), to_apply=add
    done1 = f32[2,3] all-reduce-done(start)
    ROOT done2 = f32[2,3] all-reduce-done(start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr("all-reduce-start instruction requires one consumer, found 2"));
}
1588
// all-reduce-done whose operand is an ordinary tuple (not an all-reduce-start)
// must be rejected.
TEST_F(HloVerifierTest, AllReduceDoneWithoutStart) {
  // The tuple's declared shape lists all four operands; it previously declared
  // only two elements, which did not match tuple(p0, p0, p1, p1).
  const char* const kModuleStr = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    p1 = u32[] parameter(1)
    tuple = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p0, p1, p1)
    ROOT done = f32[2,3] all-reduce-done(tuple)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("The operand of a all-reduce-done instruction "
                        "needs to be all-reduce-start, found tuple"));
}
1608
MakeAllToAllComputation(std::vector<std::vector<int64_t>> replica_groups,std::optional<int64_t> replica_count=std::nullopt,std::optional<int64_t> num_partitions=std::nullopt,absl::string_view other_attributes="")1609 StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
1610 std::vector<std::vector<int64_t>> replica_groups,
1611 std::optional<int64_t> replica_count = std::nullopt,
1612 std::optional<int64_t> num_partitions = std::nullopt,
1613 absl::string_view other_attributes = "") {
1614 const char* kTemplate = R"(
1615 HloModule test
1616 ENTRY entry {
1617 p0 = f32[128]{0} parameter(0)
1618 p1 = f32[128]{0} parameter(1)
1619 a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS
1620 OTHER_ATTRIBUTES
1621 })";
1622 return MakeCollectiveCommOpComputation(replica_groups, replica_count,
1623 num_partitions, other_attributes,
1624 kTemplate);
1625 }
1626
// With no replica groups at all, the all-to-all module is expected to verify.
TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
  const std::vector<std::vector<int64_t>> no_groups;
  TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllToAllComputation(no_groups));
  const auto verify_status = verifier().Run(module.get()).status();
  TF_ASSERT_OK(verify_status);
}
1631
// The second replica group is empty, which the verifier must reject.
TEST_F(HloVerifierTest, AllToAll_EmptyReplicaGroup) {
  const std::vector<std::vector<int64_t>> groups = {{0, 1}, {}};
  TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllToAllComputation(groups));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message, HasSubstr("cannot have an empty replica group"));
}
1637
// Replica 0 occurs in both the first and the last group.
TEST_F(HloVerifierTest, AllToAll_RepeatedReplicaId) {
  const std::vector<std::vector<int64_t>> groups = {{0, 1}, {2, 3}, {4, 0}};
  TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllToAllComputation(groups));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message, HasSubstr("Replica 0 is repeated"));
}
1644
// The groups jump from replica 3 to replica 5, so replica 4 is never named.
TEST_F(HloVerifierTest, AllToAll_MissingReplicaId) {
  const std::vector<std::vector<int64_t>> groups = {{0, 1}, {2, 3}, {5, 6}};
  TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllToAllComputation(groups));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message, HasSubstr("Replica 4 is not named"));
}
1651
// Groups of sizes 2, 1 and 2 violate the uniform-group-size requirement.
TEST_F(HloVerifierTest, AllToAll_UniformSizeOfReplicasInGroup) {
  const std::vector<std::vector<int64_t>> groups = {{0, 1}, {2}, {3, 4}};
  TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllToAllComputation(groups));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message,
              HasSubstr("Replica groups expected to be of uniform size"));
}
1658
// One replica and two partitions, but the groups name four ids.
TEST_F(HloVerifierTest, AllToAll_CrossPartition_Invalid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, MakeAllToAllComputation({{0, 1}, {2, 3}},
                                           /*replica_count=*/1,
                                           /*num_partitions=*/2,
                                           /*other_attributes=*/"channel_id=1"));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message, HasSubstr("In kCrossPartition mode, replica groups should "
                                 "contain 2 partitions, but found 4"));
}
1667
// Four partitions match the four ids named by the replica groups.
TEST_F(HloVerifierTest, AllToAll_CrossPartition_Valid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, MakeAllToAllComputation({{0, 1}, {2, 3}},
                                           /*replica_count=*/1,
                                           /*num_partitions=*/4,
                                           /*other_attributes=*/"channel_id=1"));
  const auto verify_status = verifier().Run(module.get()).status();
  TF_ASSERT_OK(verify_status);
}
1674
// The two all-to-all operands have the same dimensions but different layouts
// ({0,1} vs {1,0}); the verifier reports them as having different shapes.
TEST_F(HloVerifierTest, AllToAll_LayoutConstrained) {
  const char* const kHlo = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128,4]{0,1} parameter(0)
    p1 = f32[128,4]{1,0} parameter(1)
    ROOT a2a = (f32[128,4]{0,1}, f32[128,4]{1,0}) all-to-all(p0, p1),
      replica_groups={{0,1}}
  }
  )";
  HloModuleConfig module_config;
  module_config.set_replica_count(2);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kHlo, module_config));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message,
              HasSubstr("HLO all-to-all has operands with different shapes"));
}
1692
// Source 0 is listed in two source-target pairs.
TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) {
  const char* const kHlo = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,1}, {0,2}, {1,0}}
  }
  )";
  HloModuleConfig module_config;
  module_config.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kHlo, module_config));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message, HasSubstr("Source 0 appears more than once"));
}
1709
// Target 2 is listed in two source-target pairs.
TEST_F(HloVerifierTest, CollectivePermuteSameTargetTwice) {
  const char* const kHlo = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,2}, {1,2}, {2,0}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message, HasSubstr("Target 2 appears more than once"));
}
1724
// A four-operand collective-permute lists source 0 three times in
// source_target_pairs; the verifier reports it as appearing more than 2 times.
TEST_F(HloVerifierTest, CollectivePermuteSameSourceTooManyTimes) {
  const char* const kHlo = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    ROOT collective-permute = u32[2,8,128]{2,1,0:T(2,128)} collective-permute(u32[2,8,128] broadcast.0, u32[2,8,128] broadcast.1, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{0,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message,
              HasSubstr("Source 0 appears more than 2 times in instruction's "
                        "source-target pairs:"));
}
1747
// A four-operand collective-permute lists target 3 three times in
// source_target_pairs; the verifier reports it as appearing more than 2 times.
TEST_F(HloVerifierTest, CollectivePermuteSameTargetTooManyTimes) {
  const char* const kHlo = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    ROOT collective-permute = u32[2,8,128]{2,1,0:T(2,128)} collective-permute(u32[2,8,128] broadcast.0, u32[2,8,128] broadcast.1, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,3},{1,0}}, slice_sizes={{1,8,128},{1,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message,
              HasSubstr("Target 3 appears more than 2 times in instruction's "
                        "source-target pairs:"));
}
1770
// The input-buffer operand is a single array (broadcast.0) while the
// output-buffer operand is a two-element tuple (tuple.output).
TEST_F(HloVerifierTest, CollectivePermuteUnmatchingSourceTarget) {
  const char* const kHlo = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    constant.4 = s32[] constant(2)
    tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
    tuple.6 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.5)
    tuple.9 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.6)
    ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message, HasSubstr("Unmatching input buffers and output buffers"));
}
1798
// The input-offset operand (tuple.3, a flat tuple of three scalars) does not
// line up with the two-buffer input tuple (tuple.input).
TEST_F(HloVerifierTest, CollectivePermuteUnmatchingInputAndInputOffset) {
  const char* const kHlo = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, u32[2,8,128]{2,1,0:T(2,128)} broadcast.0)
    tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    constant.4 = s32[] constant(2)
    tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
    tuple.6 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.5)
    tuple.9 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.6)
    ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (s32[],s32[],s32[]) tuple.3, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message, HasSubstr("Unmatching input buffers and input offset."));
}
1827
// The output-offset operand (tuple.2) is a flat tuple of three scalars, which
// the verifier rejects as not matching the output buffers (tuple.output).
TEST_F(HloVerifierTest, CollectivePermuteUnmatchingOutputAndOutputOffset) {
  const char* const kHlo = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, u32[2,8,128]{2,1,0:T(2,128)} broadcast.0)
    tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    constant.4 = s32[] constant(2)
    tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
    tuple.7 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.2)
    tuple.8 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.7)
    ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.8, (s32[],s32[],s32[]) tuple.2), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message,
              HasSubstr("Unmatching output buffers and output offset."));
}
1856
// With replica_count=3, source id 5 is out of range.
TEST_F(HloVerifierTest, CollectivePermuteCrossReplicaSourceOOR) {
  const char* const kHlo = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{5,2}, {1,2}, {2,0}}
  }
  )";
  HloModuleConfig module_config;
  module_config.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kHlo, module_config));
  const auto verify_status = verifier().Run(module.get()).status();
  EXPECT_THAT(verify_status.error_message(), HasSubstr("Source 5"));
  EXPECT_THAT(verify_status.error_message(), HasSubstr("must be < 3"));
}
1875
// With replica_count=3, target id 7 is out of range.
TEST_F(HloVerifierTest, CollectivePermuteCrossReplicaTargetOOR) {
  const char* const kHlo = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,1}, {1,2}, {2,7}}
  }
  )";
  HloModuleConfig module_config;
  module_config.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kHlo, module_config));
  const auto verify_status = verifier().Run(module.get()).status();
  EXPECT_THAT(verify_status.error_message(), HasSubstr("Target 7"));
  EXPECT_THAT(verify_status.error_message(), HasSubstr("must be < 3"));
}
1894
// Cross-partition variant (channel_id set, num_partitions=3): source id 5 is
// out of range.
TEST_F(HloVerifierTest, CollectivePermuteCrossPartitionSourceOOR) {
  const char* const kHlo = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{5,2}, {1,2}, {2,0}}, channel_id=1
  }
  )";
  HloModuleConfig module_config;
  module_config.set_num_partitions(3);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kHlo, module_config));
  const auto verify_status = verifier().Run(module.get()).status();
  EXPECT_THAT(verify_status.error_message(), HasSubstr("Source 5"));
  EXPECT_THAT(verify_status.error_message(), HasSubstr("must be < 3"));
}
1913
// Cross-partition variant (channel_id set, num_partitions=3): target id 7 is
// out of range.
TEST_F(HloVerifierTest, CollectivePermuteCrossPartitionTargetOOR) {
  const char* const kHlo = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,2}, {1,7}, {2,0}}, channel_id=1
  }
  )";
  HloModuleConfig module_config;
  module_config.set_num_partitions(3);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kHlo, module_config));
  const auto verify_status = verifier().Run(module.get()).status();
  EXPECT_THAT(verify_status.error_message(), HasSubstr("Target 7"));
  EXPECT_THAT(verify_status.error_message(), HasSubstr("must be < 3"));
}
1932
// The fusion declares shape f32[10] while its fused computation returns
// f32[10,10].
TEST_F(HloVerifierTest, FusionShapeVerifier) {
  const char* const kHlo = R"(
  HloModule test

  fused_computation {
    ROOT p0 = f32[10,10] parameter(0)
  }

  ENTRY entry {
    p0 = f32[10,10] parameter(0)
    ROOT out = f32[10] fusion(p0), kind=kInput, calls=fused_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message, HasSubstr("Fused computation shape"));
}
1951
// The entry computation (no execution_thread annotation) fuses a computation
// marked execution_thread="parallel_thread".
TEST_F(HloVerifierTest, FusionThreadVerifier) {
  const char* const kHlo = R"(
  HloModule test

  fused_computation {
    ROOT p0 = f32[8,12] parameter(0)
  }, execution_thread="parallel_thread"

  ENTRY entry {
    p0 = f32[8,12] parameter(0)
    ROOT out = f32[8,12] fusion(p0), kind=kInput, calls=fused_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message,
              HasSubstr("expects parent computation thread name same as called "
                        "computation's thread name"));
}
1971
// Inside the fused computation, an all-reduce calls `add`, which is marked
// execution_thread="parallel_thread" while the fused computation is not.
TEST_F(HloVerifierTest, FusionNestedComputationThreadVerifier) {
  const char* const kHlo = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }, execution_thread="parallel_thread"

  fused_computation {
    p0 = f32[8,12] parameter(0)
    p1 = f32[8,12] parameter(1)
    crs0 = f32[8,12] all-reduce(p1), replica_groups={}, to_apply=add
    ROOT result = add(p0, crs0)
  }

  ENTRY entry {
    p0 = f32[8,12] parameter(0)
    p1 = f32[8,12] parameter(1)
    ROOT out = f32[8,12] fusion(p0, p1), kind=kInput, calls=fused_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(
      message,
      HasSubstr("Nested computations expects same computation's thread name"));
}
2001
// crs0 and crs1 differ only in constrain_layout; having both kinds in one
// module is rejected.
TEST_F(HloVerifierTest, AllReduceVerifier) {
  const char* const kHlo = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY entry {
    input = f32[8,12]{0,1} parameter(0)
    crs0 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add
    crs1 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add,
      constrain_layout=true
    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(crs0, crs1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(
      message,
      HasSubstr("mix of layout constrained and unconstrained AllReduce"));
}
2026
// channel_id=1 is shared by a send/send-done pair and an all-reduce.
TEST_F(HloVerifierTest, ChannelVerifier) {
  const char* const kHlo = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY entry {
    %input = f32[8,12] parameter(0)
    %token0 = token[] after-all()
    %send = (f32[8,12], u32[], token[]) send(%input, %token0), channel_id=1
    %send-done = token[] send-done(%send), channel_id=1
    %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
      channel_id=1
    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%input, %crs)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message,
              HasSubstr("used for different types of channel instructions"));
}
2052
// channel_id=1 is shared by a collective-permute and an all-reduce.
TEST_F(HloVerifierTest, CollectiveChannelVerifier) {
  const char* const kHlo = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY entry {
    %input = f32[8,12] parameter(0)
    %permute = f32[8,12] collective-permute(%input),
      source_target_pairs={{0,1},{1,0}}, channel_id=1
    %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
      channel_id=1
    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%permute, %crs)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  const std::string message =
      verifier().Run(module.get()).status().error_message();
  EXPECT_THAT(message,
              HasSubstr("used for different types of channel instructions"));
}
2077
// A collective-permute-start with a (f32[2,3], f32[2,3], u32[], u32[]) result,
// consumed by a single collective-permute-done, passes layout-sensitive
// verification.
TEST_F(HloVerifierTestLayoutSensitive, CollectivePermuteStartAndDone) {
  const char* const kModuleStr = R"(
  HloModule Module

  ENTRY CollectivePermuteStartAndDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
    ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  // TF_ASSERT_OK prints the verifier's error message if verification fails;
  // a bare ASSERT_TRUE(status.ok()) would only report "false".
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
2094
// collective-permute-start is declared with a plain array shape instead of the
// four-element tuple shape the verifier expects.
TEST_F(HloVerifierTest, CollectivePermuteStartAndDoneWrongType) {
  const char* const kHlo = R"(
  HloModule Module

  ENTRY CollectivePermuteStartAndDoneWrongType {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    collective-permute-start.1 = f32[2,3]{1,0:S(1)} collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
    ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  const std::string message = verify_status.error_message();
  EXPECT_THAT(message, HasSubstr("Expected instruction to have shape equal to "
                                 "(f32[2,3], f32[2,3], u32[], u32[])"));
}
2114
// Two collective-permute-done ops consume the same collective-permute-start.
TEST_F(HloVerifierTest, CollectivePermuteStartAndMultipleDone) {
  const char* const kHlo = R"(
  HloModule Module

  ENTRY CollectivePermuteStartAndMultipleDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
    collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
    ROOT collective-permute-done.2 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  const std::string message = verify_status.error_message();
  EXPECT_THAT(message,
              HasSubstr("collective-permute-start instruction requires one "
                        "consumer, found 2"));
}
2136
// collective-permute-done is fed by a plain tuple rather than a
// collective-permute-start.
TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) {
  const char* const kHlo = R"(
  HloModule Module

  ENTRY CollectivePermuteDoneNoCollectivePermuteStart {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    p1 = f32[2,3]{1,0:S(1)} parameter(1)
    p2 = u32[] parameter(2)
    p3 = u32[] parameter(3)
    tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2, p3)
    ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  const std::string message = verify_status.error_message();
  EXPECT_THAT(message,
              HasSubstr("The operand of a collective-permute-done instruction "
                        "needs to be collective-permute-start, found tuple"));
}
2159
// Comparing f32 operands with type=UNSIGNED is rejected.
TEST_F(HloVerifierTest, ComparisonTypeFloat) {
  const char* const kHlo = R"(
  HloModule Module

  ENTRY RngOperandElementTypesNotMatch {
    p0 = f32[] parameter(0)
    ROOT cmp = pred[] compare(f32[] p0, f32[] p0), direction=LT, type=UNSIGNED
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  const std::string message = verify_status.error_message();
  EXPECT_THAT(message, HasSubstr("Expected comparison type FLOAT or TOTALORDER"));
}
2177
// Comparing s32 operands with type=UNSIGNED is rejected.
TEST_F(HloVerifierTest, ComparisonTypeSigned) {
  const char* const kHlo = R"(
  HloModule Module

  ENTRY RngOperandElementTypesNotMatch {
    p0 = s32[] parameter(0)
    ROOT cmp = pred[] compare(s32[] p0, s32[] p0), direction=LT, type=UNSIGNED
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  const std::string message = verify_status.error_message();
  EXPECT_THAT(message, HasSubstr("Expected comparison type SIGNED"));
}
2195
// Comparing u32 operands with type=SIGNED is rejected.
TEST_F(HloVerifierTest, ComparisonTypeUnsigned) {
  const char* const kHlo = R"(
  HloModule Module

  ENTRY RngOperandElementTypesNotMatch {
    p0 = u32[] parameter(0)
    ROOT cmp = pred[] compare(u32[] p0, u32[] p0), direction=LT, type=SIGNED
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  const std::string message = verify_status.error_message();
  EXPECT_THAT(message, HasSubstr("Expected comparison type UNSIGNED"));
}
2213
// Comparing pred operands with type=SIGNED is rejected; the verifier asks for
// UNSIGNED instead.
TEST_F(HloVerifierTest, ComparisonTypePred) {
  const char* const kHlo = R"(
  HloModule Module

  ENTRY RngOperandElementTypesNotMatch {
    p0 = pred[] parameter(0)
    ROOT cmp = pred[] compare(pred[] p0, pred[] p0), direction=LT, type=SIGNED
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  const std::string message = verify_status.error_message();
  EXPECT_THAT(message, HasSubstr("Expected comparison type UNSIGNED"));
}
2231
// use_global_device_ids=true combined with empty replica_groups is rejected.
TEST_F(HloVerifierTest, UseGlobalDeviceIdsEmptyReplicaGroup) {
  const char* const kHlo = R"(
  HloModule Module
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, channel_id=1,
                         use_global_device_ids=true, to_apply=add
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  const std::string message = verify_status.error_message();
  EXPECT_THAT(message,
              HasSubstr("Replica groups must be specified in flattened-id mode"));
}
2255
// use_global_device_ids=true on an all-reduce that has no channel_id.
TEST_F(HloVerifierTest, InvalidChannelIDandUseGlobalDeviceIDs) {
  const char* const kHlo = R"(
  HloModule Module
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    ROOT crs = f32[8]{0} all-reduce(input), replica_groups={},
                         use_global_device_ids=true, to_apply=add
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  const std::string message = verify_status.error_message();
  EXPECT_THAT(
      message,
      HasSubstr(
          "Invalid combination of has_channel_id and use_global_device_ids"));
}
2280
// A reduce-scatter over a two-replica group whose result keeps the full f32[8]
// operand size; the verifier reports shard_count = 1 vs subgroup_size = 2.
TEST_F(HloVerifierTest, ReduceScatterInvalidOutputSize0) {
  const char* const kHlo = R"(
  HloModule Module
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    ROOT crs = f32[8]{0} reduce-scatter(input), replica_groups={{0,1}},
                         to_apply=add, dimensions={0}
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  const std::string message = verify_status.error_message();
  EXPECT_THAT(message, HasSubstr("shard_count = 1, subgroup_size = 2"));
}
2303
// A reduce-scatter with scatter dimension 1 on a rank-1 operand.
TEST_F(HloVerifierTest, ReduceScatterInvalidScatterDim) {
  const char* const kHlo = R"(
  HloModule Module
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    ROOT crs = f32[4]{0} reduce-scatter(input), replica_groups={{0,1}},
                         to_apply=add, dimensions={1}
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  const std::string message = verify_status.error_message();
  EXPECT_THAT(
      message,
      HasSubstr("ars->scatter_dimension() < ars->operand(i)->shape().rank()"));
}
2327
// Replica groups of different sizes ({0,1} has 2 members, {2,3,4} has 3) are
// not allowed for reduce-scatter and must be rejected.
TEST_F(HloVerifierTest, ReduceScatterNonUniformGroups) {
  const char* const hlo_string = R"(
  HloModule Module
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    ROOT crs = f32[4]{0} reduce-scatter(input), replica_groups={{0,1}, {2,3,4}},
      to_apply=add, dimensions={0}
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Replica groups expected to be of uniform size"));
}
2350
// With the VerifyBroadcastDimensionsOrder option enabled, a broadcast whose
// `dimensions` attribute is not in ascending order ({3,2,1}) is rejected.
TEST_F(HloVerifierTest, VerifyBroadcastDimensionsOrder) {
  const char* const hlo = R"(
  HloModule module

  ENTRY computation {
    mul = f32[32,32,32]{2,1,0} parameter(0)
    ROOT broadcast = f32[32,32,32,32]{3,2,1,0} broadcast(mul), dimensions={3,2,1}
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
  auto status = HloVerifier{HloVerifierOpts{}.VerifyBroadcastDimensionsOrder()}
                    .Run(module.get())
                    .status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Broadcast dimensions should be ordered"));
}
2369
// Counterpart of VerifyBroadcastDimensionsOrder: ascending broadcast
// dimensions ({0,3}) pass the same check.
TEST_F(HloVerifierTest, VerifyBroadcastDimensionsOrderOK) {
  const char* const hlo = R"(
  HloModule module

  ENTRY computation {
    mul = f32[4,5] parameter(0)
    ROOT broadcast = f32[4,3,2,5] broadcast(mul), dimensions={0,3}
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
  TF_ASSERT_OK(HloVerifier{HloVerifierOpts{}.VerifyBroadcastDimensionsOrder()}
                   .Run(module.get())
                   .status());
}
2385
// In layout-sensitive mode with VerifyReshapeIsBitcast enabled, a reshape
// from f32[8,3]{1,0} to f32[4,2,3]{0,1,2} is expected to be rejected: its
// output layout does not make it a physical bitcast of the operand.
TEST_F(HloVerifierTest, ReshapeIsNotBitcast) {
  const char* const hlo = R"(
  HloModule Module

  ENTRY main {
    p = f32[8,3]{1,0} parameter(0)
    ROOT r = f32[4,2,3]{0,1,2} reshape(p)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
  auto status =
      HloVerifier{
          HloVerifierOpts{}.MakeLayoutSensitive().VerifyReshapeIsBitcast()}
          .Run(module.get())
          .status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Reshape should be a physical bitcast"));
}
2406
// Counterpart of ReshapeIsNotBitcast: reshaping f32[8]{0} to f32[4,2]{1,0}
// passes the same layout-sensitive bitcast check.
TEST_F(HloVerifierTest, ReshapeIsBitcast) {
  const char* const hlo = R"(
  HloModule Module

  ENTRY main {
    p = f32[8]{0} parameter(0)
    ROOT r = f32[4,2]{1,0} reshape(p)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
  TF_ASSERT_OK(HloVerifier{
      HloVerifierOpts{}.MakeLayoutSensitive().VerifyReshapeIsBitcast()}
                   .Run(module.get())
                   .status());
}
2423
// A custom-call in the (unannotated) entry computation applies a computation
// annotated with execution_thread="parallel_thread". With
// VerifyCustomCallNestedComputationThreadName enabled, this thread-name
// mismatch must be rejected.
TEST_F(HloVerifierTest, VerifyCustomCallThread) {
  const char* const hlo = R"(
  HloModule module
  %call_body (prev.2: s32[]) -> pred[] {
    %constant.1 = s32[] constant(5)
    %prev.2 = s32[] parameter(0)
    ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT
  }, execution_thread="parallel_thread"

  ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
    %constant.2 = s32[] constant(0)
    ROOT %custom = s32[] custom-call(s32[] %constant.2), custom_call_target="MyCustomCall", to_apply=%call_body
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(hlo));
  auto status =
      HloVerifier{
          HloVerifierOpts{}.VerifyCustomCallNestedComputationThreadName()}
          .Run(module.get())
          .status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("expects parent computation thread name same as called "
                        "computation's thread name"));
}
2450
// The while's condition computation is annotated with
// execution_thread="parallel_thread", but the while lives in the unannotated
// entry computation. The default verifier must reject the mismatch.
TEST_F(HloVerifierTest, CheckWhileThread) {
  const char* const hlo_string = R"(
  HloModule While, entry_computation_layout={()->s32[]}

  %body.v3 (prev.1: s32[]) -> s32[] {
    %constant = s32[] constant(1)
    %prev.1 = s32[] parameter(0)
    ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1)
  }

  %condition.v3 (prev.2: s32[]) -> pred[] {
    %constant.1 = s32[] constant(5)
    %prev.2 = s32[] parameter(0)
    ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT
  }, execution_thread="parallel_thread"

  ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
    %constant.2 = s32[] constant(0)
    ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("expects parent computation thread name same as called "
                        "computation's thread name"));
}
2480
// A while condition in the main thread may contain an async
// custom-call-start/custom-call-done pair whose async_execution_thread is
// "parallel_thread"; the default verifier is expected to accept this.
TEST_F(HloVerifierTest, CheckWhileContainsAsyncThread) {
  const char* const hlo_string = R"(
  HloModule While, entry_computation_layout={()->s32[]}

  %async_add (prev.1: s32[]) -> s32[] {
    %constant = s32[] constant(1)
    %prev.1 = s32[] parameter(0)
    ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1)
  }, execution_thread="parallel_thread"

  %body.v3 (prev.1: s32[]) -> s32[] {
    %constant = s32[] constant(1)
    %prev.1 = s32[] parameter(0)
    ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1)
  }

  %condition.v3 (prev.2: s32[]) -> pred[] {
    %constant.1 = s32[] constant(5)
    %prev.2 = s32[] parameter(0)
    %async-start = ((s32[]), s32[], s32[]) custom-call-start(s32[] %prev.2), async_execution_thread="parallel_thread", custom_call_target="async_add"
    %async-done = s32[] custom-call-done(((s32[]), s32[], s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="async_add"
    ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %async-done), direction=GT
  }

  ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
    %constant.2 = s32[] constant(0)
    ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));
  // TF_ASSERT_OK (rather than ASSERT_TRUE(status.ok())) prints the verifier's
  // error message if this unexpectedly fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
2515
// Inside a fusion, a dynamic-update-slice whose result layout carries memory
// space S(3) updates a copy of parameter.1 whose layout has no memory-space
// annotation. The layout-sensitive fusion verifier is expected to accept this.
TEST_F(HloVerifierTestLayoutFusion, DynamicUpdateSliceWithMemorySpace) {
  const char* const hlo_string = R"(
  HloModule fusion, is_scheduled=true

  fused_computation {
    %parameter.0 = bf16[1,8,1,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)} parameter(0)
    %parameter.1 = bf16[1,8,6,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)} parameter(1)
    %c = bf16[1,8,6,8,320]{4,0,3,2,1:T(2,128)(2,1)} copy(parameter.1)
    %constant.1 = s32[] constant(0)
    ROOT %dynamic-update-slice.1 = bf16[1,8,6,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)}
      dynamic-update-slice(%c, %parameter.0, %constant.1, %constant.1,
      %constant.1, %constant.1, %constant.1)
  }

  ENTRY entry (parameter.0: bf16[1,8,1,8,320], parameter.1: bf16[1,8,6,8,320]) -> bf16[1,8,6,8,320]{
    %p0 = bf16[1,8,1,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)} parameter(0)
    %p1 = bf16[1,8,6,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)} parameter(1)
    ROOT out = bf16[1,8,6,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)} fusion(p0, p1), kind=kLoop, calls=fused_computation
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  // TF_ASSERT_OK surfaces the verifier's error message on failure, which
  // ASSERT_TRUE(status.ok()) would hide.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
2541
2542 } // namespace
2543 } // namespace xla
2544