1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "absl/strings/str_replace.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/hlo_parser.h"
27 #include "tensorflow/compiler/xla/service/layout_assignment.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/xla.pb.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35
36 namespace xla {
37 namespace {
38
39 using ::testing::HasSubstr;
40
CreateUnverifiedModule()41 std::unique_ptr<HloModule> CreateUnverifiedModule() {
42 return absl::make_unique<HloModule>("module", HloModuleConfig());
43 }
44
45 // This class cannot be converted to use HloTestBase. It explicitly
46 // uses HloTestBase to create and test malformed HLOs.
47 class HloVerifierTest : public HloTestBase {
48 public:
HloVerifierTest()49 HloVerifierTest()
50 : HloTestBase(/*verifier_layout_sensitive=*/false,
51 /*allow_mixed_precision_in_hlo_verifier=*/false) {}
52 };
53
54 class HloVerifierTestAllowMixedPrecision : public HloTestBase {
55 public:
HloVerifierTestAllowMixedPrecision()56 HloVerifierTestAllowMixedPrecision()
57 : HloTestBase(/*verifier_layout_sensitive=*/false,
58 /*allow_mixed_precision_in_hlo_verifier=*/true) {}
59 };
60
61 class HloVerifierTestLayoutSensitive : public HloTestBase {
62 public:
HloVerifierTestLayoutSensitive()63 HloVerifierTestLayoutSensitive()
64 : HloTestBase(/*verifier_layout_sensitive=*/true,
65 /*allow_mixed_precision_in_hlo_verifier=*/false,
66 LayoutAssignment::InstructionCanChangeLayout) {}
67 };
68
TEST_F(HloVerifierTest, NullInstructionParent) {
  // Build the trivially valid computation negate(param).
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder comp_builder(TestName());
  HloInstruction* const parameter = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* const negation = comp_builder.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, parameter));
  std::unique_ptr<HloModule> module = CreateUnverifiedModule();
  module->AddEntryComputation(comp_builder.Build());

  // The untouched module must verify cleanly.
  TF_ASSERT_OK(verifier().Run(module.get()).status());

  // Clearing the instruction's parent pointer must be reported.
  negation->set_parent(nullptr);
  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("has a null parent pointer"));
}
87
TEST_F(HloVerifierTest, NullComputationParent) {
  // Build negate(param) and install it as the entry computation.
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder comp_builder(TestName());
  HloInstruction* const parameter = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  comp_builder.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, parameter));
  std::unique_ptr<HloModule> module = CreateUnverifiedModule();
  HloComputation* const entry =
      module->AddEntryComputation(comp_builder.Build());

  // The untouched module must verify cleanly.
  TF_ASSERT_OK(verifier().Run(module.get()).status());

  // Clearing the computation's parent (module) pointer must be reported.
  entry->set_parent(nullptr);
  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("has a null parent pointer"));
}
106
TEST_F(HloVerifierTest, DifferentOperandParents) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  // Entry computation: negate(param).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* const entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* const negation = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, entry_param));
  std::unique_ptr<HloModule> module = CreateUnverifiedModule();
  module->AddEntryComputation(entry_builder.Build());

  // A second, embedded computation holding only a parameter.
  HloComputation::Builder embedded_builder(TestName());
  HloInstruction* const embedded_param = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  module->AddEmbeddedComputation(embedded_builder.Build());

  TF_ASSERT_OK(verifier().Run(module.get()).status());

  // Make the entry's negate consume an operand owned by the other computation.
  TF_ASSERT_OK(negation->ReplaceOperandWith(0, embedded_param));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("is in a different computation"));
}
130
TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
  HloComputation::Builder comp_builder(TestName());
  const Shape f32_len1 = ShapeUtil::MakeShape(F32, {1});
  const Shape f32_len2 = ShapeUtil::MakeShape(F32, {2});

  HloInstruction* const parameter = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_len1, "param"));

  // This add declares shape f32[2] although both operands are f32[1].
  HloInstruction* const bad_add = comp_builder.AddInstruction(
      HloInstruction::CreateBinary(f32_len2, HloOpcode::kAdd, parameter,
                                   parameter));

  // The regression only shows up when the badly shaped instruction is not the
  // root, so feed it into a multiply that becomes the root.
  comp_builder.AddInstruction(HloInstruction::CreateBinary(
      f32_len2, HloOpcode::kMultiply, bad_add, bad_add));

  std::unique_ptr<HloModule> module = CreateUnverifiedModule();
  module->AddEntryComputation(comp_builder.Build());

  // Every run must fail: the verifier may not carry DFS-visitor state from
  // one run into the next.
  for (int run = 0; run < 2; ++run) {
    EXPECT_FALSE(verifier().Run(module.get()).status().ok()) << "run " << run;
  }
}
156
TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) {
  // The call passes a (f32[4], s32[]) tuple to a computation whose parameter
  // is (s32[], f32[4]).
  constexpr char kHlo[] = R"(
HloModule Module

callme {
  ROOT param = (s32[], f32[4]) parameter(0)
}

ENTRY entry {
  p0 = (f32[4], s32[]) parameter(0)
  ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("shape does not match parameter"));
}
178
TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
  // Both branches expect (s32[], f32[4]) but receive p0 of shape
  // (f32[4], s32[]).
  constexpr char kHlo[] = R"(
HloModule Module

true_branch {
  tparam = (s32[], f32[4]) parameter(0)
  ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1
}

false_branch {
  fparam = (s32[], f32[4]) parameter(0)
  ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1
}

ENTRY entry {
  p0 = (f32[4], s32[]) parameter(0)
  constant = pred[] constant(true)
  ROOT conditional = f32[4] conditional(constant, p0, p0),
    true_computation=true_branch, false_computation=false_branch
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("shape does not match parameter"));
}
208
TEST_F(HloVerifierTest, CheckConditionalBranchIndexOperandShape) {
  const char* const hlo_string = R"(
HloModule Module

branch0 {
  tparam = f32[4] parameter(0)
  ROOT tgte1 = f32[4] ceil(tparam)
}

branch1 {
  fparam = f32[4] parameter(0)
  ROOT fgte1 = f32[4] floor(fparam)
}

branch2 {
  sparam = f32[4] parameter(0)
  ROOT sgte1 = f32[4] ceil(sparam)
}

ENTRY entry {
  p0 = f32[4] parameter(0)
  b0 = s32[] parameter(1)
  ROOT conditional = f32[4] conditional(b0, p0, p0, p0),
    branch_computations={branch0, branch1, branch2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  // The unmodified module (s32[] branch index) is well formed. Check that
  // explicitly, so the failures below can only come from the mutations.
  TF_ASSERT_OK(verifier().Run(module.get()).status());

  HloInstruction* condition = FindInstruction(module.get(), "b0");
  ASSERT_NE(condition, nullptr);

  // A floating-point scalar branch index is rejected.
  *condition->mutable_shape() = ShapeUtil::MakeShape(F32, {});
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr(
          "first operand of indexed conditional must be a scalar of S32"));

  // A non-scalar S32 branch index is rejected as well.
  *condition->mutable_shape() = ShapeUtil::MakeShape(S32, {4});
  status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("first operand of conditional must be a scalar"));
}
254
TEST_F(HloVerifierTest, RngOpnd0NotScalar) {
  // The second rng operand is an f16[2] array rather than a scalar.
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY RngOpnd0NotScalar {
  constant.0 = f32[] constant(0)
  constant.1 = f16[2] constant({1, 3})
  ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[2] constant.1),
    distribution=rng_uniform
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected scalar type"));
}
273
TEST_F(HloVerifierTest, RngOperandElementTypesDoNotMatch) {
  // The two rng operands are f32[] and f16[].
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY RngOperandElementTypesNotMatch {
  constant.0 = f32[] constant(0)
  constant.1 = f16[] constant(1)
  ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[] constant.1),
    distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected compatible element types"));
}
293
TEST_F(HloVerifierTest, RngMixedPrecisionNotAllowed) {
  // f32 operands with an f16 result; this fixture disallows mixed precision.
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY RngResultElementTypeNotMatch {
  constant.0 = f32[] constant(0)
  constant.1 = f32[] constant(1)
  ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
    distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected compatible element types"));
}
313
TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) {
  // f32 operands with an f16 result are accepted when mixed precision is
  // allowed.
  const char* const hlo_string = R"(
HloModule Module

ENTRY RngResultElementTypeNotMatch {
  constant.0 = f32[] constant(0)
  constant.1 = f32[] constant(1)
  ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
    distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // TF_ASSERT_OK prints the verifier's error on failure, whereas
  // ASSERT_TRUE(status.ok()) would only print "false".
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
331
TEST_F(HloVerifierTest, RngElementTypeNotSupported) {
  // rng_normal over s32 values.
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY RngElementTypeNotSupported {
  constant.0 = s32[] constant(0)
  constant.1 = s32[] constant(1)
  ROOT rng.0 = s32[10]{0} rng(s32[] constant.0, s32[] constant.1),
    distribution=rng_normal
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Element type not supported"));
}
350
TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
  // Built programmatically because the HLO text parser does not accept
  // negative interior padding (which is probably a feature).
  const Shape f32_vec100 = ShapeUtil::MakeShape(F32, {100});
  HloComputation::Builder comp_builder(TestName());
  HloInstruction* const operand = comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_vec100, "param"));
  HloInstruction* const pad_value = comp_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));

  PaddingConfig negative_interior;
  negative_interior.add_dimensions()->set_interior_padding(-1);
  comp_builder.AddInstruction(HloInstruction::CreatePad(
      f32_vec100, operand, pad_value, negative_interior));

  std::unique_ptr<HloModule> module = CreateUnverifiedModule();
  module->AddEntryComputation(comp_builder.Build());

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
374
TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
  // This testcase can't be written using textual HLO, because it doesn't parse
  // negative interior padding. That's probably a feature. :)
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {100}), "param"));
  PaddingConfig padding_config;
  padding_config.add_dimensions()->set_interior_padding(-1);
  builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {100}), param,
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(F32).Clone())),
      padding_config));

  auto module = CreateUnverifiedModule();
  module->AddEntryComputation(builder.Build());

  // Assert rejection first. Otherwise an unexpectedly OK status would show up
  // only as an empty-string HasSubstr mismatch.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
396
// Minimal module whose entry root is a 3x3, stride-2 f16 convolution. Tests
// parse it and then corrupt the convolution's window in place.
constexpr char kConvHloString[] = R"(
HloModule module
ENTRY entry_computation {
  param0 = f16[128,128,56,56] parameter(0)
  param1 = f16[3,3,128,128] parameter(1)
  zero_f16 = f16[] constant(0)
  ROOT conv = f16[128,128,28,28] convolution(param0, param1),
    window={size=3x3 stride=2x2}, dim_labels=bf01_01io->bf01
})";
407
TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kConvHloString));
  // Give the first spatial window dimension a negative window dilation.
  auto* conv = module->entry_computation()->root_instruction();
  Window w = conv->window();
  w.mutable_dimensions(0)->set_window_dilation(-1);
  conv->set_window(w);

  // Assert rejection before matching the message, as the other tests do.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("non-positive window dilation factor"));
}
419
TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kConvHloString));
  // Give the first spatial window dimension a negative base dilation.
  auto* conv = module->entry_computation()->root_instruction();
  Window w = conv->window();
  w.mutable_dimensions(0)->set_base_dilation(-1);
  conv->set_window(w);

  // Assert rejection before matching the message, as the other tests do.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("non-positive base area dilation factor"));
}
431
// An add whose operands have layouts {1,0} and {0,1} while its result is
// {1,0}. Shared by the layout-insensitive and layout-sensitive tests.
constexpr char kAddWithLayoutChangeHlo[] = R"(
HloModule AddWithLayoutChange
ENTRY AddWithLayoutChange {
  par0 = f32[3,4]{1,0} parameter(0)
  par1 = f32[3,4]{0,1} parameter(1)
  ROOT add0 = f32[3,4]{1,0} add(par0,par1)
}
)";
440
TEST_F(HloVerifierTest, AddWithLayoutChange) {
  // The layout-insensitive verifier accepts an add that changes layout.
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnVerifiedModule(kAddWithLayoutChangeHlo));
  // TF_ASSERT_OK reports the verifier error text on failure.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
447
TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) {
  // dynamic-slice with one scalar start index per dimension.
  const char* const kScalarIndexDynamicSlice = R"(
HloModule DynamicSlice_module

ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
  %original_parameter = s32[2,2,258] parameter(0)
  %constant = s32[] constant(0)
  %start_index = s32[] parameter(1)
  ROOT %dynamic-slice = s32[2,2,258] dynamic-slice(s32[2,2,258] %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
}
)";

  // Scalar start indices are enabled through the module's debug options.
  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
                                           kScalarIndexDynamicSlice, config));
  // TF_ASSERT_OK reports the verifier error text on failure.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
470
TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) {
  // dynamic-update-slice with one scalar start index per dimension.
  const char* const kScalarIndexDynamicUpdateSlice = R"(
HloModule DynamicUpdateSlice_module

ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
  %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
  %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
  %start_index.0 = s32[] parameter(2)
  %start_index.1 = s32[] parameter(3)
  %start_index.2 = s32[] parameter(4)
  %start_index.3 = s32[] parameter(5)
  ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3)
}
)";

  // Scalar start indices are enabled through the module's debug options.
  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      ParseAndReturnVerifiedModule(kScalarIndexDynamicUpdateSlice, config));
  // TF_ASSERT_OK reports the verifier error text on failure.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
496
TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) {
  // With layout checking on, par1's {0,1} layout differs from the add's {1,0}
  // result layout, which must be reported.
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnUnverifiedModule(kAddWithLayoutChangeHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
505
TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
  // dynamic-slice from a {0,1} operand into a {1,0} result.
  constexpr char kSliceWithLayoutChangeHlo[] = R"(
HloModule SliceWithLayoutChange
ENTRY SliceWithLayoutChange {
  par0 = f32[4,5]{0,1} parameter(0)
  par1 = s32[] parameter(1)
  par2 = s32[] parameter(2)
  ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1, par2),
    dynamic_slice_sizes={3,4}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnUnverifiedModule(kSliceWithLayoutChangeHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
524
TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) {
  // Concatenation of a {0,1} and a {1,0} operand into a {1,0} result.
  constexpr char kConcatWithLayoutChangeHlo[] = R"(
HloModule ConcatWithLayoutChange
ENTRY ConcatWithLayoutChange {
  par0 = f32[3,5]{0,1} parameter(0)
  par1 = f32[3,3]{1,0} parameter(1)
  ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1),
    dimensions={1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnUnverifiedModule(kConcatWithLayoutChangeHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
542
TEST_F(HloVerifierTest, BitcastCanNotChangeElementType) {
  // Bitcast from f32[2] to s32[2].
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY BitcastCanNotChangeElementType {
  constant.0 = f32[2] constant({0.0, 0.0})
  ROOT bitcast = s32[2] bitcast(constant.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Bitcast can not change the element type"));
}
560
TEST_F(HloVerifierTestLayoutSensitive, BitcastNeedsSameNumberOfElements) {
  // Bitcast from f32[2] (8 bytes) to f32[3] (12 bytes).
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY BitcastNeedsToBeNoOp {
  constant.0 = f32[2] constant({0.0, 0.0})
  ROOT bitcast = f32[3] bitcast(constant.0)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Bitcast cannot have different shape sizes of output "
                        "(12) and operand (8)"));
}
579
TEST_F(HloVerifierTest, SelectMixedPrecisionNotAllowed) {
  // select over f32 and bf16 branches; this fixture disallows mixed precision.
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY SelectMixedPrecisionNotAllowed {
  p0 = pred[32] parameter(0)
  p1 = f32[32] parameter(1)
  p2 = bf16[32] parameter(2)
  ROOT select = f32[32] select(p0, p1, p2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Seen floating point types of different precisions"));
}
599
TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) {
  // The same f32/bf16 select is accepted when mixed precision is allowed.
  const char* const hlo_string = R"(
HloModule Module

ENTRY SelectMixedPrecisionAllowed {
  p0 = pred[32] parameter(0)
  p1 = f32[32] parameter(1)
  p2 = bf16[32] parameter(2)
  ROOT select = f32[32] select(p0, p1, p2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // TF_ASSERT_OK reports the verifier error text on failure.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
617
TEST_F(HloVerifierTest, SelectTupleNotAllowed) {
  // select whose on-true/on-false operands are tuples.
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY SelectWithTuple {
  p0 = (f32[], f32[]) parameter(0)
  p1 = (f32[], f32[]) parameter(1)
  p2 = pred[] parameter(2)
  ROOT select = (f32[], f32[]) select(p2, p0, p1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected array argument for select"));
}
637
TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDone) {
  // A well-formed copy-start/copy-done pair moving data from memory space
  // S(1) to S(2).
  const char* const hlo_string = R"(
HloModule Module

ENTRY CopyStartAndCopyDone {
  p0 = f32[2,3]{1,0:S(1)} parameter(0)
  copy-start = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
  ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // TF_ASSERT_OK reports the verifier error text on failure.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
654
TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDoneWrongLayout) {
  // The first element of copy-start's tuple has layout {0,1}, not the
  // operand's {1,0}.
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY CopyStartAndCopyDone {
  p0 = f32[2,3]{1,0:S(1)} parameter(0)
  copy-start = (f32[2,3]{0,1:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
  ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to"));
}
673
TEST_F(HloVerifierTest, CopyStartAndCopyDoneWrongType) {
  // copy-start declared with a plain array shape instead of a tuple.
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY CopyStartAndCopyDone {
  p0 = f32[2,3] parameter(0)
  copy-start = f32[2,3] copy-start(p0)
  ROOT copy-done = f32[2,3] copy-done(copy-start)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "(f32[2,3], f32[2,3], u32[])"));
}
693
TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) {
  // One copy-start consumed by two copy-done instructions.
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY CopyStartAndCopyDone {
  p0 = f32[2,3] parameter(0)
  copy-start = (f32[2,3], f32[2,3], u32[]) copy-start(p0)
  copy-done.1 = f32[2,3] copy-done(copy-start)
  copy-done.2 = f32[2,3] copy-done(copy-start)
  ROOT tuple = (f32[2,3], f32[2,3]) tuple(copy-done.1, copy-done.2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("CopyStart instruction requires one consumer, found 2"));
}
715
TEST_F(HloVerifierTest, CopyDoneNoCopyStart) {
  // copy-done fed by an ordinary tuple rather than by a copy-start.
  constexpr char kHlo[] = R"(
HloModule Module

ENTRY CopyStartAndCopyDone {
  p0 = f32[2,3] parameter(0)
  p1 = u32[] parameter(1)
  tuple = (f32[2,3], f32[2,3], u32[]) tuple(p0, p0, p1)
  ROOT copy-done = f32[2,3] copy-done(tuple)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("The operand of a CopyDone instruction needs to be "
                        "CopyStart, found tuple"));
}
736
TEST_F(HloVerifierTest, IotaNonArrayResult) {
  // iota whose declared result is the empty tuple.
  constexpr char kHlo[] = R"(
HloModule IotaTupleResult

ENTRY kernelEntry {
  ROOT iota = () iota(), iota_dimension=24
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("does not support non-array result"));
}
754
TEST_F(HloVerifierTest, IotaNegativeDimension) {
  // iota_dimension is -1.
  constexpr char kHlo[] = R"(
HloModule IotaTupleResult

ENTRY kernelEntry {
  ROOT iota = s32[128,1001]{1,0} iota(), iota_dimension=-1
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr("negative"));
}
771
TEST_F(HloVerifierTest, IotaPredResultNotAllowed) {
  // iota producing PRED elements.
  constexpr char kHlo[] = R"(
HloModule IotaPredResult

ENTRY kernelEntry {
  ROOT iota = pred[128] iota(), iota_dimension=0
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr("got PRED"));
}
788
// map applies an f32 -> f32 computation to an f64 operand. This is rejected
// unless the verifier allows mixed precision.
constexpr char kMapOperandComputationMismatchHlo[] = R"(
HloModule MapOperandComputationMismatch

Computation {
  param0 = f32[] parameter(0)
  constant = f32[] constant(1)
  ROOT add = f32[] add(param0, constant)
}

ENTRY kernelEntry {
  param = f64[] parameter(0)
  ROOT map = f32[] map(param), dimensions={}, to_apply=Computation
})";
802
TEST_F(HloVerifierTest, MapOperandComputationMismatch) {
  // Without mixed precision, the f64 operand vs f32 parameter is an error.
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(
                                           kMapOperandComputationMismatchHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr(
          "Shape mismatch between to_apply computation parameter and operand"));
}
813
TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) {
  // With mixed precision allowed, the f64/f32 mismatch is accepted.
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
                                           kMapOperandComputationMismatchHlo));
  // TF_ASSERT_OK reports the verifier error text on failure.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
820
// reduce over f16 data with an f32 reduction computation. Whether this is
// accepted depends on the verifier's mixed-precision setting.
constexpr char kReduceOperandComputationMismatchHlo[] = R"(
HloModule ReduceOperandComputationMismatch
computation {
  x = f32[] parameter(0)
  y = f32[] parameter(1)
  ROOT add = f32[] add(x, y)
}

ENTRY kernelEntry {
  arg0 = f16[64,64,224,224]{3,2,1,0} parameter(0)
  constant = f16[] constant(0)
  reduce = f16[64]{0} reduce(arg0, constant), dimensions={0,2,3}, to_apply=computation
})";
834
TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) {
  // Without mixed precision, the f16 reduce result must match the f32 type
  // the computation produces.
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      ParseAndReturnUnverifiedModule(kReduceOperandComputationMismatchHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to f32[64]"));
}
844
TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) {
  // With mixed precision allowed, the f16/f32 reduce is accepted.
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      ParseAndReturnVerifiedModule(kReduceOperandComputationMismatchHlo));
  // TF_ASSERT_OK reports the verifier error text on failure.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
852
ReplicaGroupsStr(std::vector<std::vector<int64>> replica_groups)853 string ReplicaGroupsStr(std::vector<std::vector<int64>> replica_groups) {
854 std::vector<string> replica_group_strs;
855 for (const auto& g : replica_groups) {
856 replica_group_strs.push_back(
857 absl::StrFormat("{%s}", absl::StrJoin(g, ",")));
858 }
859 return absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "));
860 }
861
MakeAllReduceComputation(std::vector<std::vector<int64>> replica_groups)862 StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
863 std::vector<std::vector<int64>> replica_groups) {
864 const char* kTemplate = R"(
865 HloModule test
866 add {
867 x = f32[] parameter(0)
868 y = f32[] parameter(1)
869 ROOT add = f32[] add(x, y)
870 }
871 ENTRY entry {
872 p = f32[128]{0} parameter(0)
873 crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS
874 })";
875 return ParseAndReturnUnverifiedModule(absl::StrReplaceAll(
876 kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}));
877 }
878
TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) {
  // An empty replica_groups attribute ("{}") passes verification.
  TF_ASSERT_OK_AND_ASSIGN(auto crs_module, MakeAllReduceComputation({}));
  const auto verify_result = verifier().Run(crs_module.get());
  TF_ASSERT_OK(verify_result.status());
}
883
TEST_F(HloVerifierTest, AllReduce_DifferentGroupSizesOk) {
  // Groups of sizes 1, 2 and 1 are accepted for all-reduce.
  TF_ASSERT_OK_AND_ASSIGN(
      auto crs_module, MakeAllReduceComputation({{0}, {1, 3}, {2}}));
  const auto verify_result = verifier().Run(crs_module.get());
  TF_ASSERT_OK(verify_result.status());
}
889
TEST_F(HloVerifierTest, AllReduce_EmptyReplicaGroup) {
  // The second replica group has no members, which must be rejected.
  TF_ASSERT_OK_AND_ASSIGN(auto crs_module,
                          MakeAllReduceComputation({{0}, {}}));
  const auto verify_result = verifier().Run(crs_module.get());
  EXPECT_THAT(verify_result.status().error_message(),
              HasSubstr("empty replica group"));
}
895
TEST_F(HloVerifierTest, AllReduce_RepeatedReplicaId) {
  // Replica 0 is listed in both the first and the last group.
  TF_ASSERT_OK_AND_ASSIGN(
      auto crs_module, MakeAllReduceComputation({{0, 1}, {2, 3}, {4, 0}}));
  const auto verify_result = verifier().Run(crs_module.get());
  EXPECT_THAT(verify_result.status().error_message(),
              HasSubstr("Replica 0 is repeated"));
}
902
TEST_F(HloVerifierTest, AllReduce_MissingReplicaId) {
  // Ids 0-3 and 5-6 are named but 4 is skipped, so 4 is reported missing.
  TF_ASSERT_OK_AND_ASSIGN(
      auto crs_module, MakeAllReduceComputation({{0, 1}, {2, 3}, {5, 6}}));
  const auto verify_result = verifier().Run(crs_module.get());
  EXPECT_THAT(verify_result.status().error_message(),
              HasSubstr("Replica 4 is not named"));
}
909
MakeAllToAllComputation(std::vector<std::vector<int64>> replica_groups)910 StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
911 std::vector<std::vector<int64>> replica_groups) {
912 const char* kTemplate = R"(
913 HloModule test
914 add {
915 x = f32[] parameter(0)
916 y = f32[] parameter(1)
917 ROOT add = f32[] add(x, y)
918 }
919 ENTRY entry {
920 p0 = f32[128]{0} parameter(0)
921 p1 = f32[128]{0} parameter(1)
922 a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS
923 })";
924 return ParseAndReturnUnverifiedModule(absl::StrReplaceAll(
925 kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}));
926 }
927
TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
  // An empty replica_groups attribute ("{}") passes verification.
  TF_ASSERT_OK_AND_ASSIGN(auto a2a_module, MakeAllToAllComputation({}));
  const auto verify_result = verifier().Run(a2a_module.get());
  TF_ASSERT_OK(verify_result.status());
}
932
TEST_F(HloVerifierTest, AllToAll_EmptyReplicaGroup) {
  // The second replica group has no members, which must be rejected.
  TF_ASSERT_OK_AND_ASSIGN(auto a2a_module,
                          MakeAllToAllComputation({{0, 1}, {}}));
  const auto verify_result = verifier().Run(a2a_module.get());
  EXPECT_THAT(verify_result.status().error_message(),
              HasSubstr("empty replica group"));
}
938
TEST_F(HloVerifierTest, AllToAll_RepeatedReplicaId) {
  // Replica 0 is listed in both the first and the last group.
  TF_ASSERT_OK_AND_ASSIGN(
      auto a2a_module, MakeAllToAllComputation({{0, 1}, {2, 3}, {4, 0}}));
  const auto verify_result = verifier().Run(a2a_module.get());
  EXPECT_THAT(verify_result.status().error_message(),
              HasSubstr("Replica 0 is repeated"));
}
945
TEST_F(HloVerifierTest, AllToAll_MissingReplicaId) {
  // Ids 0-3 and 5-6 are named but 4 is skipped, so 4 is reported missing.
  TF_ASSERT_OK_AND_ASSIGN(
      auto a2a_module, MakeAllToAllComputation({{0, 1}, {2, 3}, {5, 6}}));
  const auto verify_result = verifier().Run(a2a_module.get());
  EXPECT_THAT(verify_result.status().error_message(),
              HasSubstr("Replica 4 is not named"));
}
952
TEST_F(HloVerifierTest, AllToAll_WrongNumberOfReplicasInGroup) {
  // Unlike all-reduce (see AllReduce_DifferentGroupSizesOk), all-to-all
  // rejects groups of differing sizes; the size-1 group {2} is reported.
  TF_ASSERT_OK_AND_ASSIGN(
      auto a2a_module, MakeAllToAllComputation({{0, 1}, {2}, {3, 4}}));
  const auto verify_result = verifier().Run(a2a_module.get());
  EXPECT_THAT(verify_result.status().error_message(),
              HasSubstr("Replica group has size 1"));
}
959
TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) {
  // Source 0 sends to both 1 and 2; a source may appear only once.
  const char* const kHlo = R"(
HloModule test
ENTRY entry {
p0 = f32[128] parameter(0)
ROOT permute = f32[128] collective-permute(p0),
source_target_pairs={{0,1}, {0,2}, {1,0}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto permute_module,
                          ParseAndReturnUnverifiedModule(kHlo));
  const auto verify_result = verifier().Run(permute_module.get());
  EXPECT_THAT(verify_result.status().error_message(),
              HasSubstr("Source 0 appears more than once"));
}
974
TEST_F(HloVerifierTest, CollectivePermuteSameTargetTwice) {
  // Both 0 and 1 send to target 2; a target may appear only once.
  const char* const kHlo = R"(
HloModule test
ENTRY entry {
p0 = f32[128] parameter(0)
ROOT permute = f32[128] collective-permute(p0),
source_target_pairs={{0,2}, {1,2}, {2,0}}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto permute_module,
                          ParseAndReturnUnverifiedModule(kHlo));
  const auto verify_result = verifier().Run(permute_module.get());
  EXPECT_THAT(verify_result.status().error_message(),
              HasSubstr("Target 2 appears more than once"));
}
989
TEST_F(HloVerifierTest, FusionShapeVerifier) {
  // The fused computation's root is f32[10,10] but the fusion instruction
  // declares f32[10]; the verifier must flag the disagreement.
  const char* const kHlo = R"(
HloModule test

fused_computation {
ROOT p0 = f32[10,10] parameter(0)
}

ENTRY entry {
p0 = f32[10,10] parameter(0)
ROOT out = f32[10] fusion(p0), kind=kInput, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto fusion_module,
                          ParseAndReturnUnverifiedModule(kHlo));
  const auto verify_result = verifier().Run(fusion_module.get());
  EXPECT_THAT(verify_result.status().error_message(),
              HasSubstr("Fused computation shape"));
}
1008
TEST_F(HloVerifierTest, AllReduceVerifier) {
  // crs0 is layout-unconstrained while crs1 sets constrain_layout=true;
  // mixing the two kinds in one module must be rejected.
  const char* const kHlo = R"(
HloModule test

add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
input = f32[8,12]{0,1} parameter(0)
crs0 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add
crs1 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add,
constrain_layout=true
ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(crs0, crs1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto crs_module,
                          ParseAndReturnUnverifiedModule(kHlo));
  const auto verify_result = verifier().Run(crs_module.get());
  EXPECT_THAT(
      verify_result.status().error_message(),
      HasSubstr("mix of layout constrained and unconstrained AllReduce"));
}
1033
TEST_F(HloVerifierTest, ChannelVerifier) {
  // A send/send-done pair and an all-reduce share channel_id=1; one channel
  // id must not be reused across different kinds of channel instructions.
  const char* const kHlo = R"(
HloModule test

add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
%input = f32[8,12] parameter(0)
%token0 = token[] after-all()
%send = (f32[8,12], u32[], token[]) send(%input, %token0), channel_id=1
%send-done = token[] send-done(%send), channel_id=1
%crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
channel_id=1
ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%input, %crs)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto channel_module,
                          ParseAndReturnUnverifiedModule(kHlo));
  const auto verify_result = verifier().Run(channel_module.get());
  EXPECT_THAT(verify_result.status().error_message(),
              HasSubstr("used for different types of channel instructions"));
}
1059
TEST_F(HloVerifierTest, CollectiveChannelVerifier) {
  // A collective-permute and an all-reduce share channel_id=1; one channel id
  // must not be reused across different kinds of channel instructions.
  const char* const kHlo = R"(
HloModule test

add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}

ENTRY entry {
%input = f32[8,12] parameter(0)
%permute = f32[8,12] collective-permute(%input),
source_target_pairs={{0,1},{1,0}}, channel_id=1
%crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
channel_id=1
ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%permute, %crs)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto channel_module,
                          ParseAndReturnUnverifiedModule(kHlo));
  const auto verify_result = verifier().Run(channel_module.get());
  EXPECT_THAT(verify_result.status().error_message(),
              HasSubstr("used for different types of channel instructions"));
}
1084
1085 } // namespace
1086 } // namespace xla
1087