• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "absl/strings/str_replace.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/hlo_parser.h"
27 #include "tensorflow/compiler/xla/service/layout_assignment.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/xla.pb.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 
36 namespace xla {
37 namespace {
38 
39 using ::testing::HasSubstr;
40 
CreateUnverifiedModule()41 std::unique_ptr<HloModule> CreateUnverifiedModule() {
42   return absl::make_unique<HloModule>("module", HloModuleConfig());
43 }
44 
45 // This class cannot be converted to use HloTestBase. It explicitly
46 // uses HloTestBase to create and test malformed HLOs.
47 class HloVerifierTest : public HloTestBase {
48  public:
HloVerifierTest()49   HloVerifierTest()
50       : HloTestBase(/*verifier_layout_sensitive=*/false,
51                     /*allow_mixed_precision_in_hlo_verifier=*/false) {}
52 };
53 
54 class HloVerifierTestAllowMixedPrecision : public HloTestBase {
55  public:
HloVerifierTestAllowMixedPrecision()56   HloVerifierTestAllowMixedPrecision()
57       : HloTestBase(/*verifier_layout_sensitive=*/false,
58                     /*allow_mixed_precision_in_hlo_verifier=*/true) {}
59 };
60 
61 class HloVerifierTestLayoutSensitive : public HloTestBase {
62  public:
HloVerifierTestLayoutSensitive()63   HloVerifierTestLayoutSensitive()
64       : HloTestBase(/*verifier_layout_sensitive=*/true,
65                     /*allow_mixed_precision_in_hlo_verifier=*/false,
66                     LayoutAssignment::InstructionCanChangeLayout) {}
67 };
68 
// A well-formed module must fail verification once one of its instructions
// has its parent pointer cleared.
TEST_F(HloVerifierTest, NullInstructionParent) {
  HloComputation::Builder b(TestName());
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloInstruction* p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* neg = b.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, p0));
  auto hlo_module = CreateUnverifiedModule();
  hlo_module->AddEntryComputation(b.Build());

  // Baseline: the untouched module verifies.
  TF_ASSERT_OK(verifier().Run(hlo_module.get()).status());

  neg->set_parent(nullptr);

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("has a null parent pointer"));
}
87 
// A well-formed module must fail verification once its entry computation is
// detached from the module (parent set to null).
TEST_F(HloVerifierTest, NullComputationParent) {
  HloComputation::Builder b(TestName());
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloInstruction* p0 = b.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  b.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, p0));
  auto hlo_module = CreateUnverifiedModule();
  HloComputation* entry = hlo_module->AddEntryComputation(b.Build());

  // Baseline: the untouched module verifies.
  TF_ASSERT_OK(verifier().Run(hlo_module.get()).status());

  entry->set_parent(nullptr);

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("has a null parent pointer"));
}
106 
// An instruction whose operand lives in a different computation must be
// rejected.
TEST_F(HloVerifierTest, DifferentOperandParents) {
  HloComputation::Builder entry_builder(TestName());
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloInstruction* entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* neg = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, entry_param));
  auto hlo_module = CreateUnverifiedModule();
  hlo_module->AddEntryComputation(entry_builder.Build());

  // A second, embedded computation that holds only a parameter.
  HloComputation::Builder embedded_builder(TestName());
  HloInstruction* embedded_param = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  hlo_module->AddEmbeddedComputation(embedded_builder.Build());

  TF_ASSERT_OK(verifier().Run(hlo_module.get()).status());
  // Rewire the negate to consume an instruction owned by the other
  // computation.
  TF_ASSERT_OK(neg->ReplaceOperandWith(0, embedded_param));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("is in a different computation"));
}
130 
// Regression test: repeated Run() calls on the same malformed module must
// each fail, i.e. the verifier must not keep visitor state between runs.
TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
  HloComputation::Builder b(TestName());
  const Shape shape_1 = ShapeUtil::MakeShape(F32, {1});
  const Shape shape_2 = ShapeUtil::MakeShape(F32, {2});

  HloInstruction* p0 =
      b.AddInstruction(HloInstruction::CreateParameter(0, shape_1, "param"));

  // add(f32[1], f32[1]) declared as f32[2]: a deliberate shape error.
  HloInstruction* bad_add = b.AddInstruction(
      HloInstruction::CreateBinary(shape_2, HloOpcode::kAdd, p0, p0));

  // The mis-shaped instruction must not be the root for the original bug to
  // reproduce, so a multiply consumes it.
  b.AddInstruction(HloInstruction::CreateBinary(shape_2, HloOpcode::kMultiply,
                                                bad_add, bad_add));

  auto hlo_module = CreateUnverifiedModule();
  hlo_module->AddEntryComputation(b.Build());

  for (int run = 0; run < 2; ++run) {
    EXPECT_FALSE(verifier().Run(hlo_module.get()).status().ok());
  }
}
156 
// The call passes a (f32[4], s32[]) tuple to a computation whose parameter is
// (s32[], f32[4]).
TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) {
  constexpr char kHlo[] = R"(
  HloModule Module

  callme {
    ROOT param = (s32[], f32[4]) parameter(0)
  }

  ENTRY entry {
    p0 = (f32[4], s32[]) parameter(0)
    ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("shape does not match parameter"));
}
178 
// Both branches expect (s32[], f32[4]) but receive (f32[4], s32[]).
TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
  constexpr char kHlo[] = R"(
  HloModule Module

  true_branch {
    tparam = (s32[], f32[4]) parameter(0)
    ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1
  }

  false_branch {
    fparam = (s32[], f32[4]) parameter(0)
    ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1
  }

  ENTRY entry {
    p0 = (f32[4], s32[]) parameter(0)
    constant = pred[] constant(true)
    ROOT conditional = f32[4] conditional(constant, p0, p0),
      true_computation=true_branch, false_computation=false_branch
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("shape does not match parameter"));
}
208 
// An indexed conditional's first operand must be an S32 scalar. Starting from
// a valid module, the branch index is corrupted twice: first to a scalar of
// the wrong type, then to a non-scalar S32.
TEST_F(HloVerifierTest, CheckConditionalBranchIndexOperandShape) {
  const char* const hlo_string = R"(
  HloModule Module

  branch0 {
    tparam = f32[4] parameter(0)
    ROOT tgte1 = f32[4] ceil(tparam)
  }

  branch1 {
    fparam = f32[4] parameter(0)
    ROOT fgte1 = f32[4] floor(fparam)
  }

  branch2 {
    sparam = f32[4] parameter(0)
    ROOT sgte1 = f32[4] ceil(sparam)
  }

  ENTRY entry {
    p0 = f32[4] parameter(0)
    b0 = s32[] parameter(1)
    ROOT conditional = f32[4] conditional(b0, p0, p0, p0),
      branch_computations={branch0, branch1, branch2}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));
  // The unmodified module (s32[] branch index) must verify. Otherwise the
  // failures asserted below could stem from an unrelated problem.
  TF_ASSERT_OK(verifier().Run(module.get()).status());

  HloInstruction* condition = FindInstruction(module.get(), "b0");
  ASSERT_NE(condition, nullptr);

  // A scalar of the wrong element type is rejected...
  *condition->mutable_shape() = ShapeUtil::MakeShape(F32, {});
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr(
          "first operand of indexed conditional must be a scalar of S32"));

  // ...and so is an S32 operand that is not a scalar.
  *condition->mutable_shape() = ShapeUtil::MakeShape(S32, {4});
  status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("first operand of conditional must be a scalar"));
}
254 
// rng's second operand is f16[2] rather than a scalar.
TEST_F(HloVerifierTest, RngOpnd0NotScalar) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY RngOpnd0NotScalar {
   constant.0 = f32[] constant(0)
   constant.1 = f16[2] constant({1, 3})
   ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[2] constant.1),
    distribution=rng_uniform
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("Expected scalar type"));
}
273 
// rng receives an f32 and an f16 operand.
TEST_F(HloVerifierTest, RngOperandElementTypesDoNotMatch) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY RngOperandElementTypesNotMatch {
   constant.0 = f32[] constant(0)
   constant.1 = f16[] constant(1)
   ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected compatible element types"));
}
293 
// f32 operands producing an f16 result: mixed precision, rejected by this
// fixture's verifier.
TEST_F(HloVerifierTest, RngMixedPrecisionNotAllowed) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY RngResultElementTypeNotMatch {
   constant.0 = f32[] constant(0)
   constant.1 = f32[] constant(1)
   ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected compatible element types"));
}
313 
// The same f32 -> f16 rng is accepted when mixed precision is allowed.
TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) {
  const char* const hlo_string = R"(
  HloModule Module

  ENTRY RngResultElementTypeNotMatch {
   constant.0 = f32[] constant(0)
   constant.1 = f32[] constant(1)
   ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // TF_ASSERT_OK prints the verifier's error on failure; a bare
  // ASSERT_TRUE(status.ok()) would only report "false".
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
331 
// rng_normal over s32 operands and an s32 result is rejected.
TEST_F(HloVerifierTest, RngElementTypeNotSupported) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY RngElementTypeNotSupported {
   constant.0 = s32[] constant(0)
   constant.1 = s32[] constant(1)
   ROOT rng.0 = s32[10]{0} rng(s32[] constant.0, s32[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("Element type not supported"));
}
350 
// A pad with interior padding of -1 must be rejected. The module is built
// programmatically because the HLO text parser does not accept negative
// interior padding.
TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
  HloComputation::Builder b(TestName());
  HloInstruction* operand = b.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {100}), "param"));
  HloInstruction* pad_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));

  PaddingConfig config;
  config.add_dimensions()->set_interior_padding(-1);
  b.AddInstruction(HloInstruction::CreatePad(ShapeUtil::MakeShape(F32, {100}),
                                             operand, pad_value, config));

  auto hlo_module = CreateUnverifiedModule();
  hlo_module->AddEntryComputation(b.Build());

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
374 
// A pad whose interior padding (the dilation between input elements) is -1
// must be rejected. The module is built programmatically because the HLO text
// parser does not accept negative interior padding.
TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {100}), "param"));
  PaddingConfig padding_config;
  padding_config.add_dimensions()->set_interior_padding(-1);
  builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {100}), param,
      // LiteralUtil::Zero already returns a fresh Literal; cloning it again
      // would only make a redundant copy.
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
      padding_config));

  auto module = CreateUnverifiedModule();
  module->AddEntryComputation(builder.Build());

  // Assert failure first, as the other tests do, so an unexpected success is
  // reported as such rather than as an empty-message mismatch.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
396 
// Simple, well-formed module containing a convolution as the root. Each
// 56-wide spatial dimension with a size-3, stride-2 window and one element of
// padding on each side yields (56 + 2 - 3) / 2 + 1 = 28 outputs, matching the
// declared f16[128,128,28,28] result. Without the padding the result would be
// 27 wide, and the module would be invalid before any test mutates it.
static const char* const kConvHloString = R"(
HloModule module
ENTRY entry_computation {
  param0 = f16[128,128,56,56] parameter(0)
  param1 = f16[3,3,128,128] parameter(1)
  zero_f16 = f16[] constant(0)
  ROOT conv = f16[128,128,28,28] convolution(param0, param1),
    window={size=3x3 stride=2x2 pad=1_1x1_1}, dim_labels=bf01_01io->bf01
})";
407 
// A convolution whose window dilation is negative must be rejected.
TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kConvHloString));
  auto* conv = module->entry_computation()->root_instruction();
  // Copy the window, corrupt window dimension 0, and write it back.
  Window w = conv->window();
  w.mutable_dimensions(0)->set_window_dilation(-1);
  conv->set_window(w);

  // Assert failure explicitly, consistent with the other negative tests.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("non-positive window dilation factor"));
}
419 
// A convolution whose base (input) dilation is negative must be rejected.
TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kConvHloString));
  auto* conv = module->entry_computation()->root_instruction();
  // Copy the window, corrupt window dimension 0, and write it back.
  Window w = conv->window();
  w.mutable_dimensions(0)->set_base_dilation(-1);
  conv->set_window(w);

  // Assert failure explicitly, consistent with the other negative tests.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("non-positive base area dilation factor"));
}
431 
// An add whose operands use different layouts ({1,0} and {0,1}) while the
// result uses {1,0}. It is valid for a layout-insensitive verifier and
// invalid for a layout-sensitive one.
constexpr char kAddWithLayoutChangeHlo[] = R"(
   HloModule AddWithLayoutChange
    ENTRY AddWithLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[3,4]{0,1} parameter(1)
      ROOT add0 = f32[3,4]{1,0} add(par0,par1)
    }
  )";
440 
// Without layout checking, an add that mixes operand layouts is accepted.
TEST_F(HloVerifierTest, AddWithLayoutChange) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnVerifiedModule(kAddWithLayoutChangeHlo));
  // TF_ASSERT_OK surfaces the verifier's message if this ever fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
447 
// A dynamic-slice with one scalar s32[] start index per dimension verifies
// when xla_allow_scalar_index_dynamic_ops is enabled.
TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) {
  const char* const kScalarIndexDynamicSlice = R"(
    HloModule DynamicSlice_module

    ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
      %original_parameter = s32[2,2,258] parameter(0)
      %constant = s32[] constant(0)
      %start_index = s32[] parameter(1)
      ROOT %dynamic-slice = s32[2,2,258] dynamic-slice(s32[2,2,258] %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
    }
  )";

  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
                                           kScalarIndexDynamicSlice, config));
  // TF_ASSERT_OK surfaces the verifier's message if this ever fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
470 
// A dynamic-update-slice with one scalar s32[] start index per dimension
// verifies when xla_allow_scalar_index_dynamic_ops is enabled.
TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) {
  // Named for the op it exercises (it was misnamed kScalarIndexDynamicSlice).
  const char* const kScalarIndexDynamicUpdateSlice = R"(
    HloModule DynamicUpdateSlice_module

    ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
      %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
      %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
      %start_index.0 = s32[] parameter(2)
      %start_index.1 = s32[] parameter(3)
      %start_index.2 = s32[] parameter(4)
      %start_index.3 = s32[] parameter(5)
      ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3)
    }
  )";

  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      ParseAndReturnVerifiedModule(kScalarIndexDynamicUpdateSlice, config));
  // TF_ASSERT_OK surfaces the verifier's message if this ever fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
496 
// With layout checking on, the add's mixed operand layouts are rejected.
TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnUnverifiedModule(kAddWithLayoutChangeHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
505 
// A dynamic-slice from a {0,1} operand to a {1,0} result changes layout.
TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
  constexpr char kHlo[] = R"(
   HloModule SliceWithLayoutChange
    ENTRY SliceWithLayoutChange {
      par0 = f32[4,5]{0,1} parameter(0)
      par1 = s32[] parameter(1)
      par2 = s32[] parameter(2)
      ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1, par2),
        dynamic_slice_sizes={3,4}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
524 
// A concatenate mixing a {0,1} operand with a {1,0} result changes layout.
TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) {
  constexpr char kHlo[] = R"(
   HloModule ConcatWithLayoutChange
   ENTRY ConcatWithLayoutChange {
      par0 = f32[3,5]{0,1} parameter(0)
      par1 = f32[3,3]{1,0} parameter(1)
      ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1),
        dimensions={1}
   }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
542 
// A bitcast from f32[2] to s32[2] changes the element type.
TEST_F(HloVerifierTest, BitcastCanNotChangeElementType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY BitcastCanNotChangeElementType {
   constant.0 = f32[2] constant({0.0, 0.0})
   ROOT bitcast = s32[2] bitcast(constant.0)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Bitcast can not change the element type"));
}
560 
// A bitcast from f32[2] (8 bytes) to f32[3] (12 bytes) changes the size.
TEST_F(HloVerifierTestLayoutSensitive, BitcastNeedsSameNumberOfElements) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY BitcastNeedsToBeNoOp {
   constant.0 = f32[2] constant({0.0, 0.0})
   ROOT bitcast = f32[3] bitcast(constant.0)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Bitcast cannot have different shape sizes of output "
                        "(12) and operand (8)"));
}
579 
// select over f32 and bf16 operands is mixed precision, rejected here.
TEST_F(HloVerifierTest, SelectMixedPrecisionNotAllowed) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY SelectMixedPrecisionNotAllowed {
   p0 = pred[32] parameter(0)
   p1 = f32[32] parameter(1)
   p2 = bf16[32] parameter(2)
   ROOT select = f32[32] select(p0, p1, p2)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Seen floating point types of different precisions"));
}
599 
// The same f32/bf16 select is accepted when mixed precision is allowed.
TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) {
  const char* const hlo_string = R"(
  HloModule Module

  ENTRY SelectMixedPrecisionAllowed {
   p0 = pred[32] parameter(0)
   p1 = f32[32] parameter(1)
   p2 = bf16[32] parameter(2)
   ROOT select = f32[32] select(p0, p1, p2)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // TF_ASSERT_OK surfaces the verifier's message if this ever fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
617 
// select over tuple-shaped operands is rejected; arrays are required.
TEST_F(HloVerifierTest, SelectTupleNotAllowed) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY SelectWithTuple {
    p0 = (f32[], f32[]) parameter(0)
    p1 = (f32[], f32[]) parameter(1)
    p2 = pred[] parameter(2)
    ROOT select = (f32[], f32[]) select(p2, p0, p1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected array argument for select"));
}
637 
// A copy-start / copy-done pair moving data between memory spaces S(1) and
// S(2) with consistent layouts verifies under layout checking.
TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDone) {
  const char* const hlo_string = R"(
  HloModule Module

  ENTRY CopyStartAndCopyDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    copy-start = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
    ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // TF_ASSERT_OK surfaces the verifier's message if this ever fails.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
654 
// copy-start's first tuple element is declared {0,1} while the operand and
// copy-done use {1,0}.
TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDoneWrongLayout) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY CopyStartAndCopyDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    copy-start = (f32[2,3]{0,1:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
    ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected instruction to have shape equal to"));
}
673 
// copy-start declared as a plain array instead of the expected 3-tuple.
TEST_F(HloVerifierTest, CopyStartAndCopyDoneWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY CopyStartAndCopyDone {
    p0 = f32[2,3] parameter(0)
    copy-start = f32[2,3] copy-start(p0)
    ROOT copy-done = f32[2,3] copy-done(copy-start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "(f32[2,3], f32[2,3], u32[])"));
}
693 
// One copy-start feeding two copy-dones violates the single-consumer rule.
TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY CopyStartAndCopyDone {
    p0 = f32[2,3] parameter(0)
    copy-start = (f32[2,3], f32[2,3], u32[]) copy-start(p0)
    copy-done.1 = f32[2,3] copy-done(copy-start)
    copy-done.2 = f32[2,3] copy-done(copy-start)
    ROOT tuple = (f32[2,3], f32[2,3]) tuple(copy-done.1, copy-done.2)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(
      result.error_message(),
      HasSubstr("CopyStart instruction requires one consumer, found 2"));
}
715 
// copy-done whose operand is a tuple rather than a copy-start.
TEST_F(HloVerifierTest, CopyDoneNoCopyStart) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY CopyStartAndCopyDone {
    p0 = f32[2,3] parameter(0)
    p1 = u32[] parameter(1)
    tuple = (f32[2,3], f32[2,3], u32[]) tuple(p0, p0, p1)
    ROOT copy-done = f32[2,3] copy-done(tuple)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("The operand of a CopyDone instruction needs to be "
                        "CopyStart, found tuple"));
}
736 
// An iota with an empty-tuple result shape is rejected.
TEST_F(HloVerifierTest, IotaNonArrayResult) {
  constexpr char kHlo[] = R"(
  HloModule IotaTupleResult

  ENTRY  kernelEntry {
    ROOT iota = () iota(), iota_dimension=24
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("does not support non-array result"));
}
754 
// An iota with iota_dimension=-1 is rejected.
TEST_F(HloVerifierTest, IotaNegativeDimension) {
  // The module name previously read "IotaTupleResult", copy-pasted from
  // IotaNonArrayResult; it now describes this test's case.
  const char* const hlo_string = R"(
  HloModule IotaNegativeDimension

  ENTRY  kernelEntry {
    ROOT iota = s32[128,1001]{1,0} iota(), iota_dimension=-1
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("negative"));
}
771 
// An iota producing pred elements is rejected.
TEST_F(HloVerifierTest, IotaPredResultNotAllowed) {
  constexpr char kHlo[] = R"(
  HloModule IotaPredResult

  ENTRY  kernelEntry {
    ROOT iota = pred[128] iota(), iota_dimension=0
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("got PRED"));
}
788 
// HLO in which a map feeds an f64 operand into an f32 `to_apply` computation.
// The default verifier rejects it; the mixed-precision verifier accepts it.
constexpr char kMapOperandComputationMismatchHlo[] = R"(
  HloModule MapOperandComputationMismatch

  Computation {
    param0 = f32[] parameter(0)
    constant = f32[] constant(1)
    ROOT add = f32[] add(param0, constant)
  }

  ENTRY kernelEntry {
  param = f64[] parameter(0)
  ROOT map = f32[] map(param), dimensions={}, to_apply=Computation
})";
802 
// Without mixed precision, the f64 map operand vs. f32 computation parameter
// is a shape mismatch.
TEST_F(HloVerifierTest, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module,
      ParseAndReturnUnverifiedModule(kMapOperandComputationMismatchHlo));
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Shape mismatch between to_apply computation "
                        "parameter and operand"));
}
813 
// With mixed precision allowed, the f64 operand feeding the f32 map
// computation is accepted.
TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
                                           kMapOperandComputationMismatchHlo));
  // TF_ASSERT_OK prints the verifier's error message on failure, whereas
  // ASSERT_TRUE(status.ok()) would only report "false".
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
820 
// HLO in which an f16 reduce uses an f32 reducer computation. The default
// verifier rejects it; the mixed-precision verifier accepts it.
constexpr char kReduceOperandComputationMismatchHlo[] = R"(
  HloModule ReduceOperandComputationMismatch
  computation {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }

  ENTRY kernelEntry {
    arg0 = f16[64,64,224,224]{3,2,1,0} parameter(0)
    constant = f16[] constant(0)
    reduce = f16[64]{0} reduce(arg0, constant), dimensions={0,2,3}, to_apply=computation
  })";
834 
// Without mixed precision, the reduce's declared f16 result disagrees with the
// f32 shape implied by its reducer.
TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module,
      ParseAndReturnUnverifiedModule(kReduceOperandComputationMismatchHlo));
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("Expected instruction to have shape equal to f32[64]"));
}
844 
// With mixed precision allowed, the f16 reduce using an f32 reducer is
// accepted.
TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      ParseAndReturnVerifiedModule(kReduceOperandComputationMismatchHlo));
  // TF_ASSERT_OK prints the verifier's error message on failure, whereas
  // ASSERT_TRUE(status.ok()) would only report "false".
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
852 
ReplicaGroupsStr(std::vector<std::vector<int64>> replica_groups)853 string ReplicaGroupsStr(std::vector<std::vector<int64>> replica_groups) {
854   std::vector<string> replica_group_strs;
855   for (const auto& g : replica_groups) {
856     replica_group_strs.push_back(
857         absl::StrFormat("{%s}", absl::StrJoin(g, ",")));
858   }
859   return absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "));
860 }
861 
MakeAllReduceComputation(std::vector<std::vector<int64>> replica_groups)862 StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
863     std::vector<std::vector<int64>> replica_groups) {
864   const char* kTemplate = R"(
865   HloModule test
866   add {
867     x = f32[] parameter(0)
868     y = f32[] parameter(1)
869     ROOT add = f32[] add(x, y)
870   }
871   ENTRY entry {
872     p = f32[128]{0} parameter(0)
873     crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS
874   })";
875   return ParseAndReturnUnverifiedModule(absl::StrReplaceAll(
876       kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}));
877 }
878 
// An all-reduce with an empty replica_groups list verifies successfully.
TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) {
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, MakeAllReduceComputation({}));
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
883 
// All-reduce replica groups may have different sizes, as long as together
// they name every replica exactly once.
TEST_F(HloVerifierTest, AllReduce_DifferentGroupSizesOk) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, MakeAllReduceComputation({{0}, {1, 3}, {2}}));
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
889 
// An all-reduce that lists an empty replica group must be rejected.
TEST_F(HloVerifierTest, AllReduce_EmptyReplicaGroup) {
  TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllReduceComputation({{0}, {}}));
  const auto status = verifier().Run(module.get()).status();
  // Assert failure first, like the other negative tests, so a verifier that
  // wrongly accepts the module fails here with a clear message.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("empty replica group"));
}
895 
// Replica 0 appears in two all-reduce groups, which must be rejected.
TEST_F(HloVerifierTest, AllReduce_RepeatedReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          MakeAllReduceComputation({{0, 1}, {2, 3}, {4, 0}}));
  const auto status = verifier().Run(module.get()).status();
  // Assert failure first, like the other negative tests, so a verifier that
  // wrongly accepts the module fails here with a clear message.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("Replica 0 is repeated"));
}
902 
// The all-reduce groups name replicas 0-3 and 5-6 but skip replica 4, which
// must be rejected.
TEST_F(HloVerifierTest, AllReduce_MissingReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          MakeAllReduceComputation({{0, 1}, {2, 3}, {5, 6}}));
  const auto status = verifier().Run(module.get()).status();
  // Assert failure first, like the other negative tests, so a verifier that
  // wrongly accepts the module fails here with a clear message.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("Replica 4 is not named"));
}
909 
MakeAllToAllComputation(std::vector<std::vector<int64>> replica_groups)910 StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
911     std::vector<std::vector<int64>> replica_groups) {
912   const char* kTemplate = R"(
913   HloModule test
914   add {
915     x = f32[] parameter(0)
916     y = f32[] parameter(1)
917     ROOT add = f32[] add(x, y)
918   }
919   ENTRY entry {
920     p0 = f32[128]{0} parameter(0)
921     p1 = f32[128]{0} parameter(1)
922     a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS
923   })";
924   return ParseAndReturnUnverifiedModule(absl::StrReplaceAll(
925       kTemplate, {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)}}));
926 }
927 
// An all-to-all with an empty replica_groups list verifies successfully.
TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, MakeAllToAllComputation({}));
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
932 
// An all-to-all that lists an empty replica group must be rejected.
TEST_F(HloVerifierTest, AllToAll_EmptyReplicaGroup) {
  TF_ASSERT_OK_AND_ASSIGN(auto module, MakeAllToAllComputation({{0, 1}, {}}));
  const auto status = verifier().Run(module.get()).status();
  // Assert failure first, like the other negative tests, so a verifier that
  // wrongly accepts the module fails here with a clear message.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("empty replica group"));
}
938 
// Replica 0 appears in two all-to-all groups, which must be rejected.
TEST_F(HloVerifierTest, AllToAll_RepeatedReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          MakeAllToAllComputation({{0, 1}, {2, 3}, {4, 0}}));
  const auto status = verifier().Run(module.get()).status();
  // Assert failure first, like the other negative tests, so a verifier that
  // wrongly accepts the module fails here with a clear message.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("Replica 0 is repeated"));
}
945 
// The all-to-all groups name replicas 0-3 and 5-6 but skip replica 4, which
// must be rejected.
TEST_F(HloVerifierTest, AllToAll_MissingReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          MakeAllToAllComputation({{0, 1}, {2, 3}, {5, 6}}));
  const auto status = verifier().Run(module.get()).status();
  // Assert failure first, like the other negative tests, so a verifier that
  // wrongly accepts the module fails here with a clear message.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("Replica 4 is not named"));
}
952 
// The all-to-all groups have sizes 2, 1, 2. The verifier must reject the
// size-1 group. Unlike all-reduce, all-to-all presumably requires every group
// to have the same size (the number of operands here).
TEST_F(HloVerifierTest, AllToAll_WrongNumberOfReplicasInGroup) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          MakeAllToAllComputation({{0, 1}, {2}, {3, 4}}));
  const auto status = verifier().Run(module.get()).status();
  // Assert failure first, like the other negative tests, so a verifier that
  // wrongly accepts the module fails here with a clear message.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("Replica group has size 1"));
}
959 
// A collective-permute in which source 0 appears in two pairs must be
// rejected.
TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) {
  const char* const kModuleStr = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,1}, {0,2}, {1,0}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kModuleStr));
  const auto status = verifier().Run(module.get()).status();
  // Assert failure first, like the other negative tests, so a verifier that
  // wrongly accepts the module fails here with a clear message.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Source 0 appears more than once"));
}
974 
// A collective-permute in which target 2 appears in two pairs must be
// rejected.
TEST_F(HloVerifierTest, CollectivePermuteSameTargetTwice) {
  const char* const kModuleStr = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,2}, {1,2}, {2,0}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kModuleStr));
  const auto status = verifier().Run(module.get()).status();
  // Assert failure first, like the other negative tests, so a verifier that
  // wrongly accepts the module fails here with a clear message.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Target 2 appears more than once"));
}
989 
// The fusion declares an f32[10] result while its fused computation's root is
// f32[10,10]; the verifier must report the mismatch.
TEST_F(HloVerifierTest, FusionShapeVerifier) {
  const char* const kModuleStr = R"(
  HloModule test

  fused_computation {
    ROOT p0 = f32[10,10] parameter(0)
  }

  ENTRY entry {
    p0 = f32[10,10] parameter(0)
    ROOT out = f32[10] fusion(p0), kind=kInput, calls=fused_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kModuleStr));
  const auto status = verifier().Run(module.get()).status();
  // Assert failure first, like the other negative tests, so a verifier that
  // wrongly accepts the module fails here with a clear message.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("Fused computation shape"));
}
1008 
// A module that mixes a layout-constrained all-reduce (crs1) with an
// unconstrained one (crs0) must be rejected.
TEST_F(HloVerifierTest, AllReduceVerifier) {
  const char* const kModuleStr = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY entry {
    input = f32[8,12]{0,1} parameter(0)
    crs0 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add
    crs1 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add,
      constrain_layout=true
    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(crs0, crs1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kModuleStr));
  const auto status = verifier().Run(module.get()).status();
  // Assert failure first, like the other negative tests, so a verifier that
  // wrongly accepts the module fails here with a clear message.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr("mix of layout constrained and unconstrained AllReduce"));
}
1033 
// channel_id=1 is shared by a send/send-done pair and an all-reduce; reusing
// one channel across different instruction kinds must be rejected.
TEST_F(HloVerifierTest, ChannelVerifier) {
  const char* const kModuleStr = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY entry {
    %input = f32[8,12] parameter(0)
    %token0 = token[] after-all()
    %send = (f32[8,12], u32[], token[]) send(%input, %token0), channel_id=1
    %send-done = token[] send-done(%send), channel_id=1
    %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
      channel_id=1
    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%input, %crs)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kModuleStr));
  const auto status = verifier().Run(module.get()).status();
  // Assert failure first, like the other negative tests, so a verifier that
  // wrongly accepts the module fails here with a clear message.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("used for different types of channel instructions"));
}
1059 
// channel_id=1 is shared by a collective-permute and an all-reduce; reusing
// one channel across different collective kinds must be rejected.
TEST_F(HloVerifierTest, CollectiveChannelVerifier) {
  const char* const kModuleStr = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY entry {
    %input = f32[8,12] parameter(0)
    %permute = f32[8,12] collective-permute(%input),
      source_target_pairs={{0,1},{1,0}}, channel_id=1
    %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
      channel_id=1
    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%permute, %crs)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kModuleStr));
  const auto status = verifier().Run(module.get()).status();
  // Assert failure first, like the other negative tests, so a verifier that
  // wrongly accepts the module fails here with a clear message.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("used for different types of channel instructions"));
}
1084 
1085 }  // namespace
1086 }  // namespace xla
1087