1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "absl/strings/str_replace.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/hlo_parser.h"
27 #include "tensorflow/compiler/xla/service/layout_assignment.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/xla.pb.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35
36 namespace xla {
37 namespace {
38
39 using ::testing::HasSubstr;
40
CreateUnverifiedModule()41 std::unique_ptr<HloModule> CreateUnverifiedModule() {
42 return absl::make_unique<HloModule>("module", HloModuleConfig());
43 }
44
// This class cannot be converted to use HloVerifiedTestBase. It explicitly
// uses HloTestBase to create and test malformed HLOs.
47 class HloVerifierTest : public HloTestBase {
48 public:
HloVerifierTest()49 HloVerifierTest()
50 : HloTestBase(/*verifier_layout_sensitive=*/false,
51 /*allow_mixed_precision_in_hlo_verifier=*/false) {}
52 };
53
54 class HloVerifierTestAllowMixedPrecision : public HloTestBase {
55 public:
HloVerifierTestAllowMixedPrecision()56 HloVerifierTestAllowMixedPrecision()
57 : HloTestBase(/*verifier_layout_sensitive=*/false,
58 /*allow_mixed_precision_in_hlo_verifier=*/true) {}
59 };
60
61 class HloVerifierTestLayoutSensitive : public HloTestBase {
62 public:
HloVerifierTestLayoutSensitive()63 HloVerifierTestLayoutSensitive()
64 : HloTestBase(/*verifier_layout_sensitive=*/true,
65 /*allow_mixed_precision_in_hlo_verifier=*/false,
66 LayoutAssignment::InstructionCanChangeLayout) {}
67 };
68
// A module that verifies cleanly must be rejected once one of its
// instructions has its parent pointer cleared.
TEST_F(HloVerifierTest, NullInstructionParent) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  HloComputation::Builder b(TestName());
  HloInstruction* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* neg = b.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, p0));

  auto hlo_module = CreateUnverifiedModule();
  hlo_module->AddEntryComputation(b.Build());

  // Baseline: the untouched module is valid.
  TF_ASSERT_OK(verifier().Run(hlo_module.get()).status());

  // Orphan the negate instruction and verify again.
  neg->set_parent(nullptr);
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("has a null parent pointer"));
}
87
// A module that verifies cleanly must be rejected once its entry computation
// has its parent pointer cleared.
TEST_F(HloVerifierTest, NullComputationParent) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  HloComputation::Builder b(TestName());
  HloInstruction* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32_scalar, "param"));
  b.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, p0));

  auto hlo_module = CreateUnverifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());

  // Baseline: the untouched module is valid.
  TF_ASSERT_OK(verifier().Run(hlo_module.get()).status());

  // Detach the computation from its module and verify again.
  entry->set_parent(nullptr);
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("has a null parent pointer"));
}
106
// An instruction whose operand lives in another computation must be rejected.
TEST_F(HloVerifierTest, DifferentOperandParents) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  // Entry computation: negate(param).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* neg = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, entry_param));
  auto hlo_module = CreateUnverifiedModule();
  hlo_module->AddEntryComputation(entry_builder.Build());

  // Embedded computation holding just a parameter.
  HloComputation::Builder embedded_builder(TestName());
  HloInstruction* foreign_param = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  hlo_module->AddEmbeddedComputation(embedded_builder.Build());

  // Valid so far; now rewire the negate to the embedded computation's
  // parameter.
  TF_ASSERT_OK(verifier().Run(hlo_module.get()).status());
  TF_ASSERT_OK(neg->ReplaceOperandWith(0, foreign_param));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("is in a different computation"));
}
130
// Running the verifier twice on the same invalid module must fail both times.
TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
  const Shape one_elem = ShapeUtil::MakeShape(F32, {1});
  const Shape two_elems = ShapeUtil::MakeShape(F32, {2});

  HloComputation::Builder b(TestName());
  HloInstruction* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, one_elem, "param"));

  // The add is deliberately given the wrong result shape (f32[2], while its
  // operands are f32[1]).
  HloInstruction* bad_add = b.AddInstruction(
      HloInstruction::CreateBinary(two_elems, HloOpcode::kAdd, p0, p0));

  // The bug under test only shows up when the badly-shaped instruction is not
  // the computation root, so put a multiply on top of it.
  b.AddInstruction(HloInstruction::CreateBinary(two_elems, HloOpcode::kMultiply,
                                                bad_add, bad_add));

  auto hlo_module = CreateUnverifiedModule();
  hlo_module->AddEntryComputation(b.Build());

  // Both runs must fail: the DFS visitor must not carry state from one run
  // into the next.
  EXPECT_FALSE(verifier().Run(hlo_module.get()).status().ok());
  EXPECT_FALSE(verifier().Run(hlo_module.get()).status().ok());
}
156
// A call whose operand shape differs from the callee's parameter shape must
// be rejected (here the tuple element order is swapped).
TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) {
  const char* const kHloText = R"(
HloModule Module

callme {
ROOT param = (s32[], f32[4]) parameter(0)
}

ENTRY entry {
p0 = (f32[4], s32[]) parameter(0)
ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("shape does not match parameter"));
}
178
// A conditional whose branch operands do not match the branch computations'
// parameter shapes must be rejected.
TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
  const char* const kHloText = R"(
HloModule Module

true_branch {
tparam = (s32[], f32[4]) parameter(0)
ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1
}

false_branch {
fparam = (s32[], f32[4]) parameter(0)
ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1
}

ENTRY entry {
p0 = (f32[4], s32[]) parameter(0)
constant = pred[] constant(true)
ROOT conditional = f32[4] conditional(constant, p0, p0),
true_computation=true_branch, false_computation=false_branch
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("shape does not match parameter"));
}
208
// Checks the shape requirements on the first (branch index) operand of an
// indexed conditional. The module is parsed with a scalar s32 index, then the
// index instruction's shape is overwritten in place with two bad shapes and
// the verifier is expected to reject each one.
TEST_F(HloVerifierTest, CheckConditionalBranchIndexOperandShape) {
  const char* const hlo_string = R"(
HloModule Module

branch0 {
tparam = f32[4] parameter(0)
ROOT tgte1 = f32[4] ceil(tparam)
}

branch1 {
fparam = f32[4] parameter(0)
ROOT fgte1 = f32[4] floor(fparam)
}

branch2 {
sparam = f32[4] parameter(0)
ROOT sgte1 = f32[4] ceil(sparam)
}

ENTRY entry {
p0 = f32[4] parameter(0)
b0 = s32[] parameter(1)
ROOT conditional = f32[4] conditional(b0, p0, p0, p0),
branch_computations={branch0, branch1, branch2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  // Baseline: the module as written must verify. Previously this status was
  // computed and then discarded, so a broken baseline would have gone
  // unnoticed and the failures below could have been caused by something
  // other than the mutated index shape.
  TF_ASSERT_OK(verifier().Run(module.get()).status());

  HloInstruction* condition = FindInstruction(module.get(), "b0");
  // Guard the dereferences below against the lookup failing.
  ASSERT_NE(condition, nullptr);

  // Scalar, but with element type F32 instead of S32.
  *condition->mutable_shape() = ShapeUtil::MakeShape(F32, {});
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr(
          "first operand of indexed conditional must be a scalar of S32"));

  // S32, but a rank-1 array instead of a scalar.
  *condition->mutable_shape() = ShapeUtil::MakeShape(S32, {4});
  status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("first operand of conditional must be a scalar"));
}
254
// rng with a non-scalar operand (constant.1 is f16[2]) must be rejected.
TEST_F(HloVerifierTest, RngOpnd0NotScalar) {
  const char* const kHloText = R"(
HloModule Module

ENTRY RngOpnd0NotScalar {
constant.0 = f32[] constant(0)
constant.1 = f16[2] constant({1, 3})
ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[2] constant.1),
distribution=rng_uniform
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected scalar type"));
}
273
// rng whose two operands have different element types (f32 vs f16) must be
// rejected.
TEST_F(HloVerifierTest, RngOperandElementTypesDoNotMatch) {
  const char* const kHloText = R"(
HloModule Module

ENTRY RngOperandElementTypesNotMatch {
constant.0 = f32[] constant(0)
constant.1 = f16[] constant(1)
ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected compatible element types"));
}
293
// With the default fixture, an rng producing f16 from f32 operands must be
// rejected.
TEST_F(HloVerifierTest, RngMixedPrecisionNotAllowed) {
  const char* const kHloText = R"(
HloModule Module

ENTRY RngResultElementTypeNotMatch {
constant.0 = f32[] constant(0)
constant.1 = f32[] constant(1)
ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected compatible element types"));
}
313
// With the mixed-precision fixture, an rng producing f16 from f32 operands
// must be accepted.
TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) {
  const char* const kHloText = R"(
HloModule Module

ENTRY RngResultElementTypeNotMatch {
constant.0 = f32[] constant(0)
constant.1 = f32[] constant(1)
ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
331
// rng_normal over s32 operands and result must be rejected.
TEST_F(HloVerifierTest, RngElementTypeNotSupported) {
  const char* const kHloText = R"(
HloModule Module

ENTRY RngElementTypeNotSupported {
constant.0 = s32[] constant(0)
constant.1 = s32[] constant(1)
ROOT rng.0 = s32[10]{0} rng(s32[] constant.0, s32[] constant.1),
distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Element type not supported"));
}
350
// A pad instruction with interior padding of -1 must be rejected.
TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
  // Built by hand rather than from textual HLO: the text parser does not
  // accept negative interior padding (which is arguably a feature).
  const Shape vec100 = ShapeUtil::MakeShape(F32, {100});

  HloComputation::Builder b(TestName());
  HloInstruction* input =
      b.AddInstruction(HloInstruction::CreateParameter(0, vec100, "param"));

  PaddingConfig bad_padding;
  bad_padding.add_dimensions()->set_interior_padding(-1);

  HloInstruction* pad_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
  b.AddInstruction(
      HloInstruction::CreatePad(vec100, input, pad_value, bad_padding));

  auto hlo_module = CreateUnverifiedModule();
  hlo_module->AddEntryComputation(b.Build());

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
374
// A pad instruction with interior padding of -1 must produce the "Interior
// padding cannot be negative" error.
TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
  // Built by hand rather than from textual HLO: the text parser does not
  // accept negative interior padding (which is arguably a feature).
  const Shape vec100 = ShapeUtil::MakeShape(F32, {100});

  HloComputation::Builder b(TestName());
  HloInstruction* input =
      b.AddInstruction(HloInstruction::CreateParameter(0, vec100, "param"));

  PaddingConfig bad_padding;
  bad_padding.add_dimensions()->set_interior_padding(-1);

  HloInstruction* pad_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone()));
  b.AddInstruction(
      HloInstruction::CreatePad(vec100, input, pad_value, bad_padding));

  auto hlo_module = CreateUnverifiedModule();
  hlo_module->AddEntryComputation(b.Build());

  EXPECT_THAT(verifier().Run(hlo_module.get()).status().error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
396
// A dot with an f32 lhs and a bf16 rhs must be accepted even by the default
// fixture.
TEST_F(HloVerifierTest, DotMixedPrecisionAllowed) {
  static const char* const kMixedDotHlo = R"(
HloModule module
ENTRY entry_computation {
a = f32[2,10] parameter(0)
b = bf16[10,2] parameter(1)
ROOT dot = f32[2,2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kMixedDotHlo));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  EXPECT_TRUE(verify_status.ok()) << verify_status;
}
411
// HLO text for a small module whose root instruction is a 3x3, stride-2
// convolution over f16 inputs. Shared by the convolution window tests.
static constexpr const char* kConvHloString = R"(
HloModule module
ENTRY entry_computation {
param0 = f16[128,128,56,56] parameter(0)
param1 = f16[3,3,128,128] parameter(1)
zero_f16 = f16[] constant(0)
ROOT conv = f16[128,128,28,28] convolution(param0, param1),
window={size=3x3 stride=2x2}, dim_labels=bf01_01io->bf01
})";
422
// A convolution whose first window dimension has window dilation -1 must be
// rejected.
TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kConvHloString));

  // Copy the root convolution's window, corrupt dimension 0, write it back.
  HloInstruction* root_conv =
      hlo_module->entry_computation()->root_instruction();
  Window bad_window = root_conv->window();
  bad_window.mutable_dimensions(0)->set_window_dilation(-1);
  root_conv->set_window(bad_window);

  EXPECT_THAT(verifier().Run(hlo_module.get()).status().error_message(),
              HasSubstr("non-positive window dilation factor"));
}
434
// A convolution whose first window dimension has base dilation -1 must be
// rejected.
TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kConvHloString));

  // Copy the root convolution's window, corrupt dimension 0, write it back.
  HloInstruction* root_conv =
      hlo_module->entry_computation()->root_instruction();
  Window bad_window = root_conv->window();
  bad_window.mutable_dimensions(0)->set_base_dilation(-1);
  root_conv->set_window(bad_window);

  EXPECT_THAT(verifier().Run(hlo_module.get()).status().error_message(),
              HasSubstr("non-positive base area dilation factor"));
}
446
// HLO text for an add whose two f32[3,4] operands use different layouts
// ({1,0} and {0,1}); the result uses {1,0}.
static constexpr const char* kAddWithLayoutChangeHlo = R"(
HloModule AddWithLayoutChange
ENTRY AddWithLayoutChange {
par0 = f32[3,4]{1,0} parameter(0)
par1 = f32[3,4]{0,1} parameter(1)
ROOT add0 = f32[3,4]{1,0} add(par0,par1)
}
)";
455
// The default (layout-insensitive) fixture must accept an add whose operands
// have different layouts.
TEST_F(HloVerifierTest, AddWithLayoutChange) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnVerifiedModule(kAddWithLayoutChangeHlo));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
462
// dynamic-slice with one scalar start index per dimension must verify when
// the xla_allow_scalar_index_dynamic_ops debug option is enabled.
TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) {
  const char* const kDynamicSliceHlo = R"(
HloModule DynamicSlice_module

ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
%original_parameter = s32[2,2,258] parameter(0)
%constant = s32[] constant(0)
%start_index = s32[] parameter(1)
ROOT %dynamic-slice = s32[2,2,258] dynamic-slice(s32[2,2,258] %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
}
)";

  // Start from the default debug options and switch the flag on.
  HloModuleConfig module_config;
  DebugOptions options = module_config.debug_options();
  options.set_xla_allow_scalar_index_dynamic_ops(true);
  module_config.set_debug_options(options);

  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module,
      ParseAndReturnVerifiedModule(kDynamicSliceHlo, module_config));
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
485
// dynamic-update-slice with one scalar start index per dimension must verify
// when the xla_allow_scalar_index_dynamic_ops debug option is enabled.
TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) {
  const char* const kDynamicUpdateSliceHlo = R"(
HloModule DynamicUpdateSlice_module

ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
%input = s32[1,1,25,1]{3,2,1,0} parameter(0)
%update = s32[1,1,2,1]{3,2,1,0} parameter(1)
%start_index.0 = s32[] parameter(2)
%start_index.1 = s32[] parameter(3)
%start_index.2 = s32[] parameter(4)
%start_index.3 = s32[] parameter(5)
ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3)
}
)";

  // Start from the default debug options and switch the flag on.
  HloModuleConfig module_config;
  DebugOptions options = module_config.debug_options();
  options.set_xla_allow_scalar_index_dynamic_ops(true);
  module_config.set_debug_options(options);

  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module,
      ParseAndReturnVerifiedModule(kDynamicUpdateSliceHlo, module_config));
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
511
// Even with mixed precision allowed, a dynamic-update-slice whose bf16 result
// differs from its f32 first operand must be rejected.
TEST_F(HloVerifierTestAllowMixedPrecision, DynamicUpdateSliceMixedPrecision) {
  const char* const kHloText = R"(
HloModule kDynamicUpdateSliceMixedPrecision

ENTRY %entry (parameter.0: f32[32,511,2048], parameter.1: bf16[32,511,512], parameter.2: s32[], parameter.3: s32[], parameter.4: s32[]) -> bf16[32,511,2048] {
%parameter.0 = f32[32,511,2048] parameter(0)
%parameter.1 = bf16[32,511,512] parameter(1)
%parameter.2 = s32[] parameter(2)
%parameter.3 = s32[] parameter(3)
%parameter.4 = s32[] parameter(4)
ROOT %dus = bf16[32,511,2048] dynamic-update-slice(f32[32,511,2048] %parameter.0, bf16[32,511,512] %parameter.1, s32[] %parameter.2, s32[] %parameter.3, s32[] %parameter.4)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "f32[32,511,2048], actual shape is bf16[32,511,2048]"));
}
533
// The layout-sensitive fixture must reject an add whose operands have
// different layouts.
TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module,
      ParseAndReturnUnverifiedModule(kAddWithLayoutChangeHlo));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
542
// The layout-sensitive fixture must reject a dynamic-slice whose result
// layout ({1,0}) differs from its operand layout ({0,1}).
TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
  const char* const kHloText = R"(
HloModule SliceWithLayoutChange
ENTRY SliceWithLayoutChange {
par0 = f32[4,5]{0,1} parameter(0)
par1 = s32[] parameter(1)
par2 = s32[] parameter(2)
ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1, par2),
dynamic_slice_sizes={3,4}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
561
// The layout-sensitive fixture must reject a concatenate whose operands use
// different layouts ({0,1} and {1,0}).
TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) {
  const char* const kHloText = R"(
HloModule ConcatWithLayoutChange
ENTRY ConcatWithLayoutChange {
par0 = f32[3,5]{0,1} parameter(0)
par1 = f32[3,3]{1,0} parameter(1)
ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1),
dimensions={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
579
// A bitcast from f32[2] to f32[3] must be rejected; the expected message
// reports sizes 12 (output) and 8 (operand).
TEST_F(HloVerifierTestLayoutSensitive, BitcastNeedsSameNumberOfElements) {
  const char* const kHloText = R"(
HloModule Module

ENTRY BitcastNeedsToBeNoOp {
constant.0 = f32[2] constant({0.0, 0.0})
ROOT bitcast = f32[3] bitcast(constant.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Bitcast cannot have different shape sizes of output "
                        "(12) and operand (8)"));
}
598
// With the default fixture, a select over an f32 and a bf16 operand must be
// rejected.
TEST_F(HloVerifierTest, SelectMixedPrecisionNotAllowed) {
  const char* const kHloText = R"(
HloModule Module

ENTRY SelectMixedPrecisionNotAllowed {
p0 = pred[32] parameter(0)
p1 = f32[32] parameter(1)
p2 = bf16[32] parameter(2)
ROOT select = f32[32] select(p0, p1, p2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Seen floating point types of different precisions"));
}
618
// With the mixed-precision fixture, a select over an f32 and a bf16 operand
// must be accepted.
TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) {
  const char* const kHloText = R"(
HloModule Module

ENTRY SelectMixedPrecisionAllowed {
p0 = pred[32] parameter(0)
p1 = f32[32] parameter(1)
p2 = bf16[32] parameter(2)
ROOT select = f32[32] select(p0, p1, p2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
636
// A select whose value operands are tuples must be rejected.
TEST_F(HloVerifierTest, SelectTupleNotAllowed) {
  const char* const kHloText = R"(
HloModule Module

ENTRY SelectWithTuple {
p0 = (f32[], f32[]) parameter(0)
p1 = (f32[], f32[]) parameter(1)
p2 = pred[] parameter(2)
ROOT select = (f32[], f32[]) select(p2, p0, p1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected array argument for select"));
}
656
// A well-formed copy-start / copy-done pair must be accepted by the
// layout-sensitive fixture.
TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDone) {
  const char* const kHloText = R"(
HloModule Module

ENTRY CopyStartAndCopyDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
copy-start = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
673
// A copy-start whose first tuple element uses layout {0,1} while its operand
// uses {1,0} must be rejected by the layout-sensitive fixture.
TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDoneWrongLayout) {
  const char* const kHloText = R"(
HloModule Module

ENTRY CopyStartAndCopyDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
copy-start = (f32[2,3]{0,1:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to"));
}
692
// A copy-start declared with a plain array result instead of the expected
// (f32[2,3], f32[2,3], u32[]) tuple must be rejected.
TEST_F(HloVerifierTest, CopyStartAndCopyDoneWrongType) {
  const char* const kHloText = R"(
HloModule Module

ENTRY CopyStartAndCopyDone {
p0 = f32[2,3] parameter(0)
copy-start = f32[2,3] copy-start(p0)
ROOT copy-done = f32[2,3] copy-done(copy-start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "(f32[2,3], f32[2,3], u32[])"));
}
712
// A copy-start consumed by two copy-done instructions must be rejected.
TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) {
  const char* const kHloText = R"(
HloModule Module

ENTRY CopyStartAndCopyDone {
p0 = f32[2,3] parameter(0)
copy-start = (f32[2,3], f32[2,3], u32[]) copy-start(p0)
copy-done.1 = f32[2,3] copy-done(copy-start)
copy-done.2 = f32[2,3] copy-done(copy-start)
ROOT tuple = (f32[2,3], f32[2,3]) tuple(copy-done.1, copy-done.2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("copy-start instruction requires one consumer, found 2"));
}
734
// A copy-done fed by a tuple instruction instead of a copy-start must be
// rejected.
TEST_F(HloVerifierTest, CopyDoneNoCopyStart) {
  const char* const kHloText = R"(
HloModule Module

ENTRY CopyStartAndCopyDone {
p0 = f32[2,3] parameter(0)
p1 = u32[] parameter(1)
tuple = (f32[2,3], f32[2,3], u32[]) tuple(p0, p0, p1)
ROOT copy-done = f32[2,3] copy-done(tuple)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("The operand of a copy-done instruction needs to be "
                        "copy-start, found tuple"));
}
755
// An iota whose result is the empty tuple must be rejected.
TEST_F(HloVerifierTest, IotaNonArrayResult) {
  const char* const kHloText = R"(
HloModule IotaTupleResult

ENTRY kernelEntry {
ROOT iota = () iota(), iota_dimension=24
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("does not support non-array result"));
}
773
// An iota with iota_dimension=-1 must be rejected.
TEST_F(HloVerifierTest, IotaNegativeDimension) {
  const char* const kHloText = R"(
HloModule IotaTupleResult

ENTRY kernelEntry {
ROOT iota = s32[128,1001]{1,0} iota(), iota_dimension=-1
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr("negative"));
}
790
// An iota with a pred result must be rejected.
TEST_F(HloVerifierTest, IotaPredResultNotAllowed) {
  const char* const kHloText = R"(
HloModule IotaPredResult

ENTRY kernelEntry {
ROOT iota = pred[128] iota(), iota_dimension=0
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr("got PRED"));
}
807
// HLO text for a map whose operand is f64[] while the applied computation's
// parameter is f32[]. Used by both MapOperandComputationMismatch tests.
static constexpr const char* kMapOperandComputationMismatchHlo = R"(
HloModule MapOperandComputationMismatch

Computation {
param0 = f32[] parameter(0)
constant = f32[] constant(1)
ROOT add = f32[] add(param0, constant)
}

ENTRY kernelEntry {
param = f64[] parameter(0)
ROOT map = f32[] map(param), dimensions={}, to_apply=Computation
})";
821
// With the default fixture, a map whose operand type differs from the
// to_apply parameter type must be rejected.
TEST_F(HloVerifierTest, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module,
      ParseAndReturnUnverifiedModule(kMapOperandComputationMismatchHlo));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr(
          "Shape mismatch between to_apply computation parameter and operand"));
}
832
// With the mixed-precision fixture, the same f64-operand / f32-parameter map
// must be accepted.
TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module,
      ParseAndReturnVerifiedModule(kMapOperandComputationMismatchHlo));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
839
// HLO text for the ReduceOperandComputationMismatch tests: an f16 operand and
// init value are reduced with a computation that works on f32 scalars.
constexpr const char* kReduceOperandComputationMismatchHlo = R"(
  HloModule ReduceOperandComputationMismatch
  computation {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }

  ENTRY kernelEntry {
    arg0 = f16[64,64,224,224]{3,2,1,0} parameter(0)
    constant = f16[] constant(0)
    reduce = f16[64]{0} reduce(arg0, constant), dimensions={0,2,3}, to_apply=computation
  })";
853
// With the default fixture, reducing f16 data through an f32 computation must
// fail with a message stating the expected f32[64] result shape.
TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto parsed_module,
      ParseAndReturnUnverifiedModule(kReduceOperandComputationMismatchHlo));

  const auto verify_status = verifier().Run(parsed_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to f32[64]"));
}
863
// Under the HloVerifierTestAllowMixedPrecision fixture, the f16/f32 reduce
// module is expected to verify successfully.
TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      ParseAndReturnVerifiedModule(kReduceOperandComputationMismatchHlo));
  // TF_ASSERT_OK (instead of ASSERT_TRUE(status.ok())) so that a failure
  // prints the verifier's error text rather than just "false".
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
871
ReplicaGroupsStr(std::vector<std::vector<int64>> replica_groups)872 string ReplicaGroupsStr(std::vector<std::vector<int64>> replica_groups) {
873 std::vector<string> replica_group_strs;
874 for (const auto& g : replica_groups) {
875 replica_group_strs.push_back(
876 absl::StrFormat("{%s}", absl::StrJoin(g, ",")));
877 }
878 return absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "));
879 }
880
ReplicaCount(const std::vector<std::vector<int64>> & replica_groups)881 int64 ReplicaCount(const std::vector<std::vector<int64>>& replica_groups) {
882 int64_t replica_count = 0;
883 for (auto group : replica_groups) {
884 replica_count += group.size();
885 }
886 return replica_count;
887 }
888
MakeCollectiveCommOpComputation(std::vector<std::vector<int64>> replica_groups,absl::optional<int64> replica_count,absl::optional<int64> num_partitions,absl::string_view other_attributes,absl::string_view template_str)889 StatusOr<std::unique_ptr<HloModule>> MakeCollectiveCommOpComputation(
890 std::vector<std::vector<int64>> replica_groups,
891 absl::optional<int64> replica_count, absl::optional<int64> num_partitions,
892 absl::string_view other_attributes, absl::string_view template_str) {
893 HloModuleConfig config;
894 config.set_replica_count(
895 replica_count.value_or(ReplicaCount(replica_groups)));
896 config.set_num_partitions(num_partitions.value_or(1));
897 return ParseAndReturnUnverifiedModule(
898 absl::StrReplaceAll(
899 template_str,
900 {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)},
901 {"OTHER_ATTRIBUTES", other_attributes.empty()
902 ? ""
903 : absl::StrCat(",", other_attributes)}}),
904 config);
905 }
906
MakeAllReduceComputation(std::vector<std::vector<int64>> replica_groups,absl::optional<int64> replica_count=absl::nullopt,absl::optional<int64> num_partitions=absl::nullopt,absl::string_view other_attributes="")907 StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
908 std::vector<std::vector<int64>> replica_groups,
909 absl::optional<int64> replica_count = absl::nullopt,
910 absl::optional<int64> num_partitions = absl::nullopt,
911 absl::string_view other_attributes = "") {
912 const char* kTemplate = R"(
913 HloModule test
914 add {
915 x = f32[] parameter(0)
916 y = f32[] parameter(1)
917 ROOT add = f32[] add(x, y)
918 }
919 ENTRY entry {
920 p = f32[128]{0} parameter(0)
921 crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS
922 OTHER_ATTRIBUTES
923 })";
924 return MakeCollectiveCommOpComputation(replica_groups, replica_count,
925 num_partitions, other_attributes,
926 kTemplate);
927 }
928
// An all-reduce with an empty replica_groups list must verify cleanly.
TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module, MakeAllReduceComputation({}));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
933
// Replica groups of differing sizes are accepted for all-reduce.
TEST_F(HloVerifierTest, AllReduce_DifferentGroupSizesOk) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0}, {1, 3}, {2}}));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
939
// A replica_groups list containing an empty group must be reported.
TEST_F(HloVerifierTest, AllReduce_EmptyReplicaGroup) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0}, {}}));
  const std::string verifier_error =
      verifier().Run(all_reduce_module.get()).status().error_message();
  EXPECT_THAT(verifier_error, HasSubstr("empty replica group"));
}
945
// Replica 0 appears in two groups; the verifier must flag the repetition.
TEST_F(HloVerifierTest, AllReduce_RepeatedReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0, 1}, {2, 3}, {4, 0}}));
  const std::string verifier_error =
      verifier().Run(all_reduce_module.get()).status().error_message();
  EXPECT_THAT(verifier_error, HasSubstr("Replica 0 is repeated"));
}
952
// The groups skip replica 4; the verifier must report it as not named.
TEST_F(HloVerifierTest, AllReduce_MissingReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0, 1}, {2, 3}, {5, 6}}));
  const std::string verifier_error =
      verifier().Run(all_reduce_module.get()).status().error_message();
  EXPECT_THAT(verifier_error, HasSubstr("Replica 4 is not named"));
}
959
// The config declares 8 replicas but the groups only list 2 of them.
TEST_F(HloVerifierTest, AllReduce_NotEnougReplicasInGroupConfig) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0, 1}}, 8));
  const std::string verifier_error =
      verifier().Run(all_reduce_module.get()).status().error_message();
  EXPECT_THAT(verifier_error,
              HasSubstr("In kCrossReplica mode, replica groups should contain "
                        "8 replicas, but found 2"));
}
966
// The config declares 2 replicas but the groups list 4 of them.
TEST_F(HloVerifierTest, AllReduce_TooManyReplicasInGroupConfig) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0, 1}, {2, 3}}, 2));
  const std::string verifier_error =
      verifier().Run(all_reduce_module.get()).status().error_message();
  EXPECT_THAT(verifier_error,
              HasSubstr("In kCrossReplica mode, replica groups should contain "
                        "2 replicas, but found 4"));
}
974
// With channel_id set, replica_count=2 and num_partitions=1, groups naming 4
// ids must be rejected with the kCrossReplicaAndPartition message.
TEST_F(HloVerifierTest, AllReduce_CrossReplicaAndPartition_Invalid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_reduce_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 2, 1, "channel_id=1"));
  const std::string verifier_error =
      verifier().Run(all_reduce_module.get()).status().error_message();
  EXPECT_THAT(
      verifier_error,
      HasSubstr(
          "In kCrossReplicaAndPartition mode, replica groups should contain "
          "2 replicas, but found 4"));
}
985
// With channel_id set, replica_count=4 and num_partitions=1, groups naming 4
// ids must verify cleanly.
TEST_F(HloVerifierTest, AllReduce_CrossReplicaAndPartition_Valid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_reduce_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 4, 1, "channel_id=1"));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
992
// With use_global_device_ids, replica_count=1 and num_partitions=2, groups
// naming 4 ids must be rejected with the kFlattenedID message.
TEST_F(HloVerifierTest, AllReduce_FlattenedID_Invalid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_reduce_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 1, 2,
                               "channel_id=1, use_global_device_ids=true"));
  const std::string verifier_error =
      verifier().Run(all_reduce_module.get()).status().error_message();
  EXPECT_THAT(verifier_error,
              HasSubstr("In kFlattenedID mode, replica groups should contain "
                        "2 flattened IDs, but found 4"));
}
1002
// With use_global_device_ids, replica_count=2 and num_partitions=2, groups
// naming 4 ids must verify cleanly.
TEST_F(HloVerifierTest, AllReduce_FlattenedID_Valid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_reduce_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 2, 2,
                               "channel_id=1, use_global_device_ids=true"));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
1010
// A well-formed all-reduce-start / all-reduce-done pair, where start yields a
// two-element tuple and done consumes it, must verify cleanly.
TEST_F(HloVerifierTest, AllReduceStartAndDone) {
  const char* const kModuleStr = R"(
  HloModule test
  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    start = (f32[2,3], f32[2,3]) all-reduce-start(p0), to_apply=add
    ROOT done = f32[2,3] all-reduce-done(start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  // TF_ASSERT_OK (instead of ASSERT_TRUE(status.ok())) so that a failure
  // prints the verifier's error text rather than just "false".
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
1031
// all-reduce-start declared with a plain array result instead of a tuple must
// be rejected, and the message must name the expected tuple shape.
TEST_F(HloVerifierTest, AllReduceStartAndDoneWrongType) {
  constexpr const char* kHloText = R"(
  HloModule test
  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    start = f32[2,3] all-reduce-start(p0), to_apply=add
    ROOT done = f32[2,3] all-reduce-done(start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(parsed_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "(f32[2,3], f32[2,3])"));
}
1054
// Two all-reduce-done instructions consuming the same all-reduce-start must
// be rejected: a start instruction requires exactly one consumer.
TEST_F(HloVerifierTest, AllReduceStartAndMultipleDone) {
  constexpr const char* kHloText = R"(
  HloModule test
  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    start = (f32[2,3], f32[2,3]) all-reduce-start(p0), to_apply=add
    done1 = f32[2,3] all-reduce-done(start)
    ROOT done2 = f32[2,3] all-reduce-done(start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(parsed_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("all-reduce-start instruction requires one consumer, found 2"));
}
1079
// all-reduce-done fed by a tuple instruction rather than all-reduce-start
// must be rejected with a message naming the offending operand opcode.
TEST_F(HloVerifierTest, AllReduceDoneWithoutStart) {
  constexpr const char* kHloText = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    p1 = u32[] parameter(1)
    tuple = (f32[2,3], f32[2,3]) tuple(p0, p0, p1, p1)
    ROOT done = f32[2,3] all-reduce-done(tuple)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(parsed_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("The operand of a all-reduce-done instruction "
                        "needs to be all-reduce-start, found tuple"));
}
1099
MakeAllToAllComputation(std::vector<std::vector<int64>> replica_groups,absl::optional<int64> replica_count=absl::nullopt,absl::optional<int64> num_partitions=absl::nullopt,absl::string_view other_attributes="")1100 StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
1101 std::vector<std::vector<int64>> replica_groups,
1102 absl::optional<int64> replica_count = absl::nullopt,
1103 absl::optional<int64> num_partitions = absl::nullopt,
1104 absl::string_view other_attributes = "") {
1105 const char* kTemplate = R"(
1106 HloModule test
1107 ENTRY entry {
1108 p0 = f32[128]{0} parameter(0)
1109 p1 = f32[128]{0} parameter(1)
1110 a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS
1111 OTHER_ATTRIBUTES
1112 })";
1113 return MakeCollectiveCommOpComputation(replica_groups, replica_count,
1114 num_partitions, other_attributes,
1115 kTemplate);
1116 }
1117
// An all-to-all with an empty replica_groups list must verify cleanly.
TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_to_all_module, MakeAllToAllComputation({}));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
1122
// A replica_groups list containing an empty group must be reported.
TEST_F(HloVerifierTest, AllToAll_EmptyReplicaGroup) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_to_all_module,
                          MakeAllToAllComputation({{0, 1}, {}}));
  const std::string verifier_error =
      verifier().Run(all_to_all_module.get()).status().error_message();
  EXPECT_THAT(verifier_error, HasSubstr("cannot have an empty replica group"));
}
1128
// Replica 0 appears in two groups; the verifier must flag the repetition.
TEST_F(HloVerifierTest, AllToAll_RepeatedReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_to_all_module,
                          MakeAllToAllComputation({{0, 1}, {2, 3}, {4, 0}}));
  const std::string verifier_error =
      verifier().Run(all_to_all_module.get()).status().error_message();
  EXPECT_THAT(verifier_error, HasSubstr("Replica 0 is repeated"));
}
1135
// The groups skip replica 4; the verifier must report it as not named.
TEST_F(HloVerifierTest, AllToAll_MissingReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_to_all_module,
                          MakeAllToAllComputation({{0, 1}, {2, 3}, {5, 6}}));
  const std::string verifier_error =
      verifier().Run(all_to_all_module.get()).status().error_message();
  EXPECT_THAT(verifier_error, HasSubstr("Replica 4 is not named"));
}
1142
// All-to-all replica groups of differing sizes must be rejected.
TEST_F(HloVerifierTest, AllToAll_UniformSizeOfReplicasInGroup) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_to_all_module,
                          MakeAllToAllComputation({{0, 1}, {2}, {3, 4}}));
  const std::string verifier_error =
      verifier().Run(all_to_all_module.get()).status().error_message();
  EXPECT_THAT(verifier_error,
              HasSubstr("Replica groups expected to be of uniform size"));
}
1149
// With channel_id set, replica_count=1 and num_partitions=2, groups naming 4
// ids must be rejected with the kCrossPartition message.
TEST_F(HloVerifierTest, AllToAll_CrossPartition_Invalid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_to_all_module,
      MakeAllToAllComputation({{0, 1}, {2, 3}}, 1, 2, "channel_id=1"));
  const std::string verifier_error =
      verifier().Run(all_to_all_module.get()).status().error_message();
  EXPECT_THAT(verifier_error,
              HasSubstr("In kCrossPartition mode, replica groups should "
                        "contain 2 partitions, but found 4"));
}
1158
// With channel_id set, replica_count=1 and num_partitions=4, groups naming 4
// ids must verify cleanly.
TEST_F(HloVerifierTest, AllToAll_CrossPartition_Valid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_to_all_module,
      MakeAllToAllComputation({{0, 1}, {2, 3}}, 1, 4, "channel_id=1"));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
1165
// The two all-to-all operands share dimensions but carry different layouts
// ({0,1} vs {1,0}); the verifier must report operands with different shapes.
TEST_F(HloVerifierTest, AllToAll_LayoutConstrained) {
  constexpr const char* kHloText = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128,4]{0,1} parameter(0)
    p1 = f32[128,4]{1,0} parameter(1)
    ROOT a2a = (f32[128,4]{0,1}, f32[128,4]{1,0}) all-to-all(p0, p1),
      replica_groups={{0,1}}
  }
  )";
  HloModuleConfig module_config;
  module_config.set_replica_count(2);
  TF_ASSERT_OK_AND_ASSIGN(
      auto parsed_module,
      ParseAndReturnUnverifiedModule(kHloText, module_config));
  const std::string verifier_error =
      verifier().Run(parsed_module.get()).status().error_message();
  EXPECT_THAT(verifier_error,
              HasSubstr("HLO all-to-all has operands with different shapes"));
}
1183
// Source 0 is listed in two source-target pairs; with replica_count=3 the
// verifier must report the duplicated source.
TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) {
  constexpr const char* kHloText = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,1}, {0,2}, {1,0}}
  }
  )";
  HloModuleConfig module_config;
  module_config.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto parsed_module,
      ParseAndReturnUnverifiedModule(kHloText, module_config));
  const std::string verifier_error =
      verifier().Run(parsed_module.get()).status().error_message();
  EXPECT_THAT(verifier_error, HasSubstr("Source 0 appears more than once"));
}
1200
// Target 2 is listed in two source-target pairs; the verifier must report the
// duplicated target. The module is parsed with the default config.
TEST_F(HloVerifierTest, CollectivePermuteSameTargetTwice) {
  constexpr const char* kHloText = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,2}, {1,2}, {2,0}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string verifier_error =
      verifier().Run(parsed_module.get()).status().error_message();
  EXPECT_THAT(verifier_error, HasSubstr("Target 2 appears more than once"));
}
1215
// Four-operand collective-permute with two slice_sizes entries in which
// source 0 appears in three pairs; the verifier must report that the source
// appears more than 2 times.
TEST_F(HloVerifierTest, CollectivePermuteSameSourceTooManyTimes) {
  constexpr const char* kHloText = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    ROOT collective-permute = u32[2,8,128]{2,1,0:T(2,128)} collective-permute(u32[2,8,128] broadcast.0, u32[2,8,128] broadcast.1, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{0,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string verifier_error =
      verifier().Run(parsed_module.get()).status().error_message();
  EXPECT_THAT(verifier_error,
              HasSubstr("Source 0 appears more than 2 times in instruction's "
                        "source-target pairs:"));
}
1238
// Four-operand collective-permute with two slice_sizes entries in which
// target 3 appears in three pairs; the verifier must report that the target
// appears more than 2 times.
TEST_F(HloVerifierTest, CollectivePermuteSameTargetTooManyTimes) {
  constexpr const char* kHloText = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    ROOT collective-permute = u32[2,8,128]{2,1,0:T(2,128)} collective-permute(u32[2,8,128] broadcast.0, u32[2,8,128] broadcast.1, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,3},{1,0}}, slice_sizes={{1,8,128},{1,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string verifier_error =
      verifier().Run(parsed_module.get()).status().error_message();
  EXPECT_THAT(verifier_error,
              HasSubstr("Target 3 appears more than 2 times in instruction's "
                        "source-target pairs:"));
}
1261
// The first operand is a single array while the second operand is a
// two-element tuple; the verifier must report unmatching input and output
// buffers.
TEST_F(HloVerifierTest, CollectivePermuteUnmatchingSourceTarget) {
  constexpr const char* kHloText = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    constant.4 = s32[] constant(2)
    tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
    tuple.6 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.5)
    tuple.9 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.6)
    ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string verifier_error =
      verifier().Run(parsed_module.get()).status().error_message();
  EXPECT_THAT(verifier_error,
              HasSubstr("Unmatching input buffers and output buffers"));
}
1289
// The input operand is a two-element tuple but the third operand (tuple.3) is
// a flat (s32[],s32[],s32[]) tuple; the verifier must report unmatching input
// buffers and input offset.
TEST_F(HloVerifierTest, CollectivePermuteUnmatchingInputAndInputOffset) {
  constexpr const char* kHloText = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, u32[2,8,128]{2,1,0:T(2,128)} broadcast.0)
    tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    constant.4 = s32[] constant(2)
    tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
    tuple.6 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.5)
    tuple.9 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.6)
    ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (s32[],s32[],s32[]) tuple.3, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string verifier_error =
      verifier().Run(parsed_module.get()).status().error_message();
  EXPECT_THAT(verifier_error,
              HasSubstr("Unmatching input buffers and input offset."));
}
1318
// The output operand is a two-element tuple but the fourth operand (tuple.2)
// is a flat (s32[],s32[],s32[]) tuple; the verifier must report unmatching
// output buffers and output offset.
TEST_F(HloVerifierTest, CollectivePermuteUnmatchingOutputAndOutputOffset) {
  constexpr const char* kHloText = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, u32[2,8,128]{2,1,0:T(2,128)} broadcast.0)
    tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    constant.4 = s32[] constant(2)
    tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
    tuple.7 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.2)
    tuple.8 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.7)
    ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.8, (s32[],s32[],s32[]) tuple.2), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string verifier_error =
      verifier().Run(parsed_module.get()).status().error_message();
  EXPECT_THAT(verifier_error,
              HasSubstr("Unmatching output buffers and output offset."));
}
1347
// No channel_id, replica_count=3: source id 5 is out of range and the error
// must name both the source and the bound.
TEST_F(HloVerifierTest, CollectivePermuteCrossReplicaSourceOOR) {
  constexpr const char* kHloText = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{5,2}, {1,2}, {2,0}}
  }
  )";
  HloModuleConfig module_config;
  module_config.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto parsed_module,
      ParseAndReturnUnverifiedModule(kHloText, module_config));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  EXPECT_THAT(verify_status.error_message(), HasSubstr("Source 5"));
  EXPECT_THAT(verify_status.error_message(), HasSubstr("must be < 3"));
}
1366
// No channel_id, replica_count=3: target id 7 is out of range and the error
// must name both the target and the bound.
TEST_F(HloVerifierTest, CollectivePermuteCrossReplicaTargetOOR) {
  constexpr const char* kHloText = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,1}, {1,2}, {2,7}}
  }
  )";
  HloModuleConfig module_config;
  module_config.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto parsed_module,
      ParseAndReturnUnverifiedModule(kHloText, module_config));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  EXPECT_THAT(verify_status.error_message(), HasSubstr("Target 7"));
  EXPECT_THAT(verify_status.error_message(), HasSubstr("must be < 3"));
}
1385
// channel_id set, num_partitions=3: source id 5 is out of range and the error
// must name both the source and the bound.
TEST_F(HloVerifierTest, CollectivePermuteCrossPartitionSourceOOR) {
  constexpr const char* kHloText = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{5,2}, {1,2}, {2,0}}, channel_id=1
  }
  )";
  HloModuleConfig module_config;
  module_config.set_num_partitions(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto parsed_module,
      ParseAndReturnUnverifiedModule(kHloText, module_config));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  EXPECT_THAT(verify_status.error_message(), HasSubstr("Source 5"));
  EXPECT_THAT(verify_status.error_message(), HasSubstr("must be < 3"));
}
1404
// channel_id set, num_partitions=3: target id 7 is out of range and the error
// must name both the target and the bound.
TEST_F(HloVerifierTest, CollectivePermuteCrossPartitionTargetOOR) {
  constexpr const char* kHloText = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,2}, {1,7}, {2,0}}, channel_id=1
  }
  )";
  HloModuleConfig module_config;
  module_config.set_num_partitions(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto parsed_module,
      ParseAndReturnUnverifiedModule(kHloText, module_config));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  EXPECT_THAT(verify_status.error_message(), HasSubstr("Target 7"));
  EXPECT_THAT(verify_status.error_message(), HasSubstr("must be < 3"));
}
1423
// The fusion instruction declares f32[10] while its fused computation's root
// is f32[10,10]; the verifier must complain about the fused computation shape.
TEST_F(HloVerifierTest, FusionShapeVerifier) {
  constexpr const char* kHloText = R"(
  HloModule test

  fused_computation {
    ROOT p0 = f32[10,10] parameter(0)
  }

  ENTRY entry {
    p0 = f32[10,10] parameter(0)
    ROOT out = f32[10] fusion(p0), kind=kInput, calls=fused_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string verifier_error =
      verifier().Run(parsed_module.get()).status().error_message();
  EXPECT_THAT(verifier_error, HasSubstr("Fused computation shape"));
}
1442
// One all-reduce sets constrain_layout=true and the other does not; the
// verifier must reject the mix of constrained and unconstrained all-reduces.
TEST_F(HloVerifierTest, AllReduceVerifier) {
  constexpr const char* kHloText = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY entry {
    input = f32[8,12]{0,1} parameter(0)
    crs0 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add
    crs1 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add,
      constrain_layout=true
    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(crs0, crs1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string verifier_error =
      verifier().Run(parsed_module.get()).status().error_message();
  EXPECT_THAT(
      verifier_error,
      HasSubstr("mix of layout constrained and unconstrained AllReduce"));
}
1467
// channel_id=1 is shared by a send/send-done pair and an all-reduce; the
// verifier must reject one channel id used for different instruction types.
TEST_F(HloVerifierTest, ChannelVerifier) {
  constexpr const char* kHloText = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY entry {
    %input = f32[8,12] parameter(0)
    %token0 = token[] after-all()
    %send = (f32[8,12], u32[], token[]) send(%input, %token0), channel_id=1
    %send-done = token[] send-done(%send), channel_id=1
    %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
      channel_id=1
    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%input, %crs)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string verifier_error =
      verifier().Run(parsed_module.get()).status().error_message();
  EXPECT_THAT(verifier_error,
              HasSubstr("used for different types of channel instructions"));
}
1493
// channel_id=1 is shared by a collective-permute and an all-reduce; the
// verifier must reject one channel id used for different instruction types.
TEST_F(HloVerifierTest, CollectiveChannelVerifier) {
  constexpr const char* kHloText = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY entry {
    %input = f32[8,12] parameter(0)
    %permute = f32[8,12] collective-permute(%input),
      source_target_pairs={{0,1},{1,0}}, channel_id=1
    %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
      channel_id=1
    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%permute, %crs)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string verifier_error =
      verifier().Run(parsed_module.get()).status().error_message();
  EXPECT_THAT(verifier_error,
              HasSubstr("used for different types of channel instructions"));
}
1518
// Positive case: a collective-permute-start whose result is the 4-tuple
// (operand, output, u32[], u32[]) and which is consumed by exactly one
// collective-permute-done must pass verification. Runs under the
// HloVerifierTestLayoutSensitive fixture; all array shapes carry an explicit
// {1,0:S(1)} layout.
TEST_F(HloVerifierTestLayoutSensitive, CollectivePermuteStartAndDone) {
  const char* const kModuleStr = R"(
HloModule Module

ENTRY CollectivePermuteStartAndDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  // TF_ASSERT_OK (rather than ASSERT_TRUE(status.ok())) so that a failure
  // prints the verifier's error message instead of just "false".
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
1535
// collective-permute-start is declared with a plain array result instead of
// the (f32[2,3], f32[2,3], u32[], u32[]) tuple; the verifier must reject the
// module and name the expected shape.
TEST_F(HloVerifierTest, CollectivePermuteStartAndDoneWrongType) {
  const char* const kHloText = R"(
HloModule Module

ENTRY CollectivePermuteStartAndDoneWrongType {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
collective-permute-start.1 = f32[2,3]{1,0:S(1)} collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
}
)";
  // Substring the verifier's error message is required to contain.
  const char* const kExpectedError =
      "Expected instruction to have shape equal to "
      "(f32[2,3], f32[2,3], u32[], u32[])";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const Status verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr(kExpectedError));
}
1555
// One collective-permute-start feeds two collective-permute-done
// instructions; the verifier must fail and report that a single consumer is
// required but 2 were found.
TEST_F(HloVerifierTest, CollectivePermuteStartAndMultipleDone) {
  const char* const kHloText = R"(
HloModule Module

ENTRY CollectivePermuteStartAndMultipleDone {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
ROOT collective-permute-done.2 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
}
)";
  // Substring the verifier's error message is required to contain.
  const char* const kExpectedError =
      "collective-permute-start instruction requires one consumer, "
      "found 2";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const Status verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr(kExpectedError));
}
1577
// collective-permute-done is fed by a tuple instruction rather than a
// collective-permute-start; the verifier must fail and name the offending
// operand opcode ("tuple").
TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) {
  const char* const kHloText = R"(
HloModule Module

ENTRY CollectivePermuteDoneNoCollectivePermuteStart {
p0 = f32[2,3]{1,0:S(1)} parameter(0)
p1 = f32[2,3]{1,0:S(1)} parameter(1)
p2 = u32[] parameter(2)
p3 = u32[] parameter(3)
tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2, p3)
ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1)
}
)";
  // Substring the verifier's error message is required to contain.
  const char* const kExpectedError =
      "The operand of a collective-permute-done instruction "
      "needs to be collective-permute-start, found tuple";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const Status verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr(kExpectedError));
}
1600
// Comparing f32 operands with type=UNSIGNED must be rejected; the error is
// expected to mention that FLOAT or TOTALORDER was required.
TEST_F(HloVerifierTest, ComparisonTypeFloat) {
  const char* const kHloText = R"(
HloModule Module

ENTRY RngOperandElementTypesNotMatch {
p0 = f32[] parameter(0)
ROOT cmp = pred[] compare(f32[] p0, f32[] p0), direction=LT, type=UNSIGNED
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const Status verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected comparison type FLOAT or TOTALORDER"));
}
1618
// Comparing s32 operands with type=UNSIGNED must be rejected; the error is
// expected to mention that SIGNED was required.
TEST_F(HloVerifierTest, ComparisonTypeSigned) {
  const char* const kHloText = R"(
HloModule Module

ENTRY RngOperandElementTypesNotMatch {
p0 = s32[] parameter(0)
ROOT cmp = pred[] compare(s32[] p0, s32[] p0), direction=LT, type=UNSIGNED
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const Status verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected comparison type SIGNED"));
}
1636
// Comparing u32 operands with type=SIGNED must be rejected; the error is
// expected to mention that UNSIGNED was required.
TEST_F(HloVerifierTest, ComparisonTypeUnsigned) {
  const char* const kHloText = R"(
HloModule Module

ENTRY RngOperandElementTypesNotMatch {
p0 = u32[] parameter(0)
ROOT cmp = pred[] compare(u32[] p0, u32[] p0), direction=LT, type=SIGNED
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const Status verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected comparison type UNSIGNED"));
}
1654
// Comparing pred operands with type=SIGNED must be rejected; the error is
// expected to mention that UNSIGNED was required.
TEST_F(HloVerifierTest, ComparisonTypePred) {
  const char* const kHloText = R"(
HloModule Module

ENTRY RngOperandElementTypesNotMatch {
p0 = pred[] parameter(0)
ROOT cmp = pred[] compare(pred[] p0, pred[] p0), direction=LT, type=SIGNED
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const Status verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected comparison type UNSIGNED"));
}
1672
// An all-reduce with channel_id=1 and use_global_device_ids=true but empty
// replica_groups must be rejected: the verifier is expected to demand
// explicit replica groups in flattened-id mode.
TEST_F(HloVerifierTest, UseGlobalDeviceIdsEmptyReplicaGroup) {
  const char* const kHloText = R"(
HloModule Module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}

ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, channel_id=1,
use_global_device_ids=true, to_apply=add
})";
  // Substring the verifier's error message is required to contain.
  const char* const kExpectedError =
      "Replica groups must be specified in flattened-id mode";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const Status verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr(kExpectedError));
}
1696
// An all-reduce that sets use_global_device_ids=true without any channel_id
// must be rejected as an invalid has_channel_id / use_global_device_ids
// combination.
TEST_F(HloVerifierTest, InvalidChannelIDandUseGlobalDeviceIDs) {
  const char* const kHloText = R"(
HloModule Module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}

ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[8]{0} all-reduce(input), replica_groups={},
use_global_device_ids=true, to_apply=add
})";
  // Substring the verifier's error message is required to contain.
  const char* const kExpectedError =
      "Invalid combination of has_channel_id and use_global_device_ids";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const Status verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr(kExpectedError));
}
1721
// reduce-scatter over replica group {0,1} whose f32[8] output has the same
// size as its f32[8] input; the verifier must fail and report
// "shard_count = 1, subgroup_size = 2".
TEST_F(HloVerifierTest, ReduceScatterInvalidOutputSize0) {
  const char* const kHloText = R"(
HloModule Module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}

ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[8]{0} reduce-scatter(input), replica_groups={{0,1}},
to_apply=add, dimensions={0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const Status verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("shard_count = 1, subgroup_size = 2"));
}
1744
// reduce-scatter on a one-dimensional f32[8] operand with dimensions={1};
// the verifier must fail, and the error is expected to contain the text of
// the scatter-dimension-vs-rank condition.
TEST_F(HloVerifierTest, ReduceScatterInvalidScatterDim) {
  const char* const kHloText = R"(
HloModule Module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}

ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[4]{0} reduce-scatter(input), replica_groups={{0,1}},
to_apply=add, dimensions={1}
})";
  // Substring the verifier's error message is required to contain.
  const char* const kExpectedError =
      "ars->scatter_dimension() < ars->operand(i)->shape().rank()";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const Status verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr(kExpectedError));
}
1768
// reduce-scatter with replica groups of different sizes ({0,1} and {2,3,4});
// the verifier must fail and report that uniformly sized groups are expected.
TEST_F(HloVerifierTest, ReduceScatterNonUniformGroups) {
  const char* const kHloText = R"(
HloModule Module
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}

ENTRY CRS {
input = f32[8]{0} parameter(0)
ROOT crs = f32[4]{0} reduce-scatter(input), replica_groups={{0,1}, {2,3,4}},
to_apply=add, dimensions={0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const Status verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica groups expected to be of uniform size"));
}
1791
1792 } // namespace
1793 } // namespace xla
1794