/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "absl/strings/str_replace.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/hlo_parser.h"
27 #include "tensorflow/compiler/xla/service/layout_assignment.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/xla.pb.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 
36 namespace xla {
37 namespace {
38 
39 using ::testing::HasSubstr;
40 
CreateUnverifiedModule()41 std::unique_ptr<HloModule> CreateUnverifiedModule() {
42   return absl::make_unique<HloModule>("module", HloModuleConfig());
43 }
44 
45 // This class cannot be converted to use HloTestBase. It explicitly
46 // uses HloTestBase to create and test malformed HLOs.
47 class HloVerifierTest : public HloTestBase {
48  public:
HloVerifierTest()49   HloVerifierTest()
50       : HloTestBase(/*verifier_layout_sensitive=*/false,
51                     /*allow_mixed_precision_in_hlo_verifier=*/false) {}
52 };
53 
54 class HloVerifierTestAllowMixedPrecision : public HloTestBase {
55  public:
HloVerifierTestAllowMixedPrecision()56   HloVerifierTestAllowMixedPrecision()
57       : HloTestBase(/*verifier_layout_sensitive=*/false,
58                     /*allow_mixed_precision_in_hlo_verifier=*/true) {}
59 };
60 
61 class HloVerifierTestLayoutSensitive : public HloTestBase {
62  public:
HloVerifierTestLayoutSensitive()63   HloVerifierTestLayoutSensitive()
64       : HloTestBase(/*verifier_layout_sensitive=*/true,
65                     /*allow_mixed_precision_in_hlo_verifier=*/false,
66                     LayoutAssignment::InstructionCanChangeLayout) {}
67 };
68 
// An instruction whose parent pointer has been cleared must be rejected, even
// though the same module verified cleanly before the tampering.
TEST_F(HloVerifierTest, NullInstructionParent) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  HloComputation::Builder b(TestName());
  HloInstruction* input =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* negated = b.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, input));

  std::unique_ptr<HloModule> module = CreateUnverifiedModule();
  module->AddEntryComputation(b.Build());

  // The untouched module is well formed.
  TF_ASSERT_OK(verifier().Run(module.get()).status());

  negated->set_parent(nullptr);

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("has a null parent pointer"));
}
87 
// A computation whose parent pointer has been cleared must be rejected, even
// though the same module verified cleanly before the tampering.
TEST_F(HloVerifierTest, NullComputationParent) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  HloComputation::Builder b(TestName());
  HloInstruction* input =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32_scalar, "param"));
  b.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, input));

  std::unique_ptr<HloModule> module = CreateUnverifiedModule();
  HloComputation* entry = module->AddEntryComputation(b.Build());

  // The untouched module is well formed.
  TF_ASSERT_OK(verifier().Run(module.get()).status());

  entry->set_parent(nullptr);

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("has a null parent pointer"));
}
106 
// Rewiring an instruction so that its operand lives in another computation
// must be reported by the verifier.
TEST_F(HloVerifierTest, DifferentOperandParents) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});

  // Entry computation: negate(param).
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* entry_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* negated = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, entry_param));
  std::unique_ptr<HloModule> module = CreateUnverifiedModule();
  module->AddEntryComputation(entry_builder.Build());

  // Embedded computation holding a single parameter.
  HloComputation::Builder embedded_builder(TestName());
  HloInstruction* foreign_param = embedded_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  module->AddEmbeddedComputation(embedded_builder.Build());

  // Valid as built; invalid once the negate reads the foreign parameter.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
  TF_ASSERT_OK(negated->ReplaceOperandWith(0, foreign_param));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("is in a different computation"));
}
130 
// Running the verifier repeatedly on the same bad module must fail every
// time; no state may leak from one run into the next.
TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
  const Shape one_element = ShapeUtil::MakeShape(F32, {1});
  const Shape two_elements = ShapeUtil::MakeShape(F32, {2});

  HloComputation::Builder b(TestName());
  HloInstruction* input = b.AddInstruction(
      HloInstruction::CreateParameter(0, one_element, "param"));

  // The add claims a result shape that does not match its operands.
  HloInstruction* bad_add = b.AddInstruction(HloInstruction::CreateBinary(
      two_elements, HloOpcode::kAdd, input, input));

  // The badly shaped instruction must not be the root of the computation for
  // the bug under test to be triggered, so a multiply is placed on top of it.
  b.AddInstruction(HloInstruction::CreateBinary(
      two_elements, HloOpcode::kMultiply, bad_add, bad_add));

  std::unique_ptr<HloModule> module = CreateUnverifiedModule();
  module->AddEntryComputation(b.Build());

  // Two consecutive runs: both have to fail, because the DFS visitor must not
  // carry state over between runs.
  for (int run = 0; run < 2; ++run) {
    EXPECT_FALSE(verifier().Run(module.get()).status().ok());
  }
}
156 
// A call whose operand shape differs from the callee's parameter shape must
// be rejected.
TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) {
  constexpr char kHloText[] = R"(
  HloModule Module

  callme {
    ROOT param = (s32[], f32[4]) parameter(0)
  }

  ENTRY entry {
    p0 = (f32[4], s32[]) parameter(0)
    ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("shape does not match parameter"));
}
178 
// A conditional whose branch operands do not match the branch computations'
// parameter shapes must be rejected.
TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
  constexpr char kHloText[] = R"(
  HloModule Module

  true_branch {
    tparam = (s32[], f32[4]) parameter(0)
    ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1
  }

  false_branch {
    fparam = (s32[], f32[4]) parameter(0)
    ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1
  }

  ENTRY entry {
    p0 = (f32[4], s32[]) parameter(0)
    constant = pred[] constant(true)
    ROOT conditional = f32[4] conditional(constant, p0, p0),
      true_computation=true_branch, false_computation=false_branch
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("shape does not match parameter"));
}
208 
// Checks the rules for the branch-index operand of an N-way conditional.
// The module is verified as written first, then the shape of the index
// parameter "b0" is corrupted in two different ways and the verifier is
// expected to report each one.
TEST_F(HloVerifierTest, CheckConditionalBranchIndexOperandShape) {
  const char* const hlo_string = R"(
  HloModule Module

  branch0 {
    tparam = f32[4] parameter(0)
    ROOT tgte1 = f32[4] ceil(tparam)
  }

  branch1 {
    fparam = f32[4] parameter(0)
    ROOT fgte1 = f32[4] floor(fparam)
  }

  branch2 {
    sparam = f32[4] parameter(0)
    ROOT sgte1 = f32[4] ceil(sparam)
  }

  ENTRY entry {
    p0 = f32[4] parameter(0)
    b0 = s32[] parameter(1)
    ROOT conditional = f32[4] conditional(b0, p0, p0, p0),
      branch_computations={branch0, branch1, branch2}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  // Baseline: the unmodified module has to verify. Previously this result was
  // stored and overwritten without being looked at, so a broken baseline
  // would have gone unnoticed.
  TF_ASSERT_OK(verifier().Run(module.get()).status());

  // Guard the lookup so a failed search is reported as a test failure instead
  // of a null dereference below.
  HloInstruction* condition = FindInstruction(module.get(), "b0");
  ASSERT_NE(condition, nullptr);

  // Scalar, but of element type F32 instead of S32.
  *condition->mutable_shape() = ShapeUtil::MakeShape(F32, {});
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr(
          "first operand of indexed conditional must be a scalar of S32"));

  // Element type S32, but no longer a scalar.
  *condition->mutable_shape() = ShapeUtil::MakeShape(S32, {4});
  status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("first operand of conditional must be a scalar"));
}
254 
// An rng instruction with a non-scalar operand must be rejected.
TEST_F(HloVerifierTest, RngOpnd0NotScalar) {
  constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY RngOpnd0NotScalar {
   constant.0 = f32[] constant(0)
   constant.1 = f16[2] constant({1, 3})
   ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[2] constant.1),
    distribution=rng_uniform
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected scalar type"));
}
273 
// An rng instruction whose two operands have different element types (f32
// and f16) must be rejected.
TEST_F(HloVerifierTest, RngOperandElementTypesDoNotMatch) {
  constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY RngOperandElementTypesNotMatch {
   constant.0 = f32[] constant(0)
   constant.1 = f16[] constant(1)
   ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected compatible element types"));
}
293 
// With mixed precision disallowed, an rng producing f16 from f32 operands
// must be rejected.
TEST_F(HloVerifierTest, RngMixedPrecisionNotAllowed) {
  constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY RngResultElementTypeNotMatch {
   constant.0 = f32[] constant(0)
   constant.1 = f32[] constant(1)
   ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected compatible element types"));
}
313 
// With mixed precision allowed, an rng producing f16 from f32 operands
// verifies successfully.
TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) {
  constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY RngResultElementTypeNotMatch {
   constant.0 = f32[] constant(0)
   constant.1 = f32[] constant(1)
   ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
331 
// An rng_normal instruction over s32 values must be rejected as an
// unsupported element type.
TEST_F(HloVerifierTest, RngElementTypeNotSupported) {
  constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY RngElementTypeNotSupported {
   constant.0 = s32[] constant(0)
   constant.1 = s32[] constant(1)
   ROOT rng.0 = s32[10]{0} rng(s32[] constant.0, s32[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Element type not supported"));
}
350 
// A pad instruction with negative interior padding must be rejected.
TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
  // Built through the C++ API rather than textual HLO, because the text
  // format does not parse negative interior padding.  That's probably a
  // feature.  :)
  const Shape vector_shape = ShapeUtil::MakeShape(F32, {100});

  HloComputation::Builder b(TestName());
  HloInstruction* operand = b.AddInstruction(
      HloInstruction::CreateParameter(0, vector_shape, "param"));
  HloInstruction* padding_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));

  PaddingConfig bad_padding;
  bad_padding.add_dimensions()->set_interior_padding(-1);
  b.AddInstruction(HloInstruction::CreatePad(vector_shape, operand,
                                             padding_value, bad_padding));

  std::unique_ptr<HloModule> module = CreateUnverifiedModule();
  module->AddEntryComputation(b.Build());

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
374 
// A pad instruction whose padding config carries a negative interior padding
// value must be rejected by the verifier.
TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
  // This testcase can't be written using textual HLO, because it doesn't parse
  // negative interior padding.  That's probably a feature.  :)
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {100}), "param"));
  PaddingConfig padding_config;
  padding_config.add_dimensions()->set_interior_padding(-1);
  // LiteralUtil::Zero already yields a temporary Literal that can be handed
  // to CreateConstant directly; the extra Clone() was a redundant copy.
  builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {100}), param,
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
      padding_config));

  auto module = CreateUnverifiedModule();
  module->AddEntryComputation(builder.Build());

  // Assert the failure explicitly, so an unexpectedly OK status is reported
  // as such rather than as a substring mismatch on an empty message.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
396 
// A dot with one f32 operand and one bf16 operand verifies successfully under
// the HloVerifierTest fixture.
TEST_F(HloVerifierTest, DotMixedPrecisionAllowed) {
  constexpr char kDotHlo[] = R"(
HloModule module
ENTRY entry_computation {
  a = f32[2,10] parameter(0)
  b = bf16[10,2] parameter(1)
  ROOT dot = f32[2,2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kDotHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  EXPECT_TRUE(verify_status.ok()) << verify_status;
}
411 
// HLO text for a small module whose root instruction is a convolution. The
// convolution tests parse this text and then corrupt the root's window.
constexpr const char* kConvHloString = R"(
HloModule module
ENTRY entry_computation {
  param0 = f16[128,128,56,56] parameter(0)
  param1 = f16[3,3,128,128] parameter(1)
  zero_f16 = f16[] constant(0)
  ROOT conv = f16[128,128,28,28] convolution(param0, param1),
    window={size=3x3 stride=2x2}, dim_labels=bf01_01io->bf01
})";
422 
// Setting a negative window dilation on the convolution's first window
// dimension must be reported by the verifier.
TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kConvHloString));

  // Copy the root's window, corrupt the copy, and write it back.
  HloInstruction* root = module->entry_computation()->root_instruction();
  Window corrupted_window = root->window();
  corrupted_window.mutable_dimensions(0)->set_window_dilation(-1);
  root->set_window(corrupted_window);

  const auto verify_status = verifier().Run(module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("non-positive window dilation factor"));
}
434 
// Setting a negative base dilation on the convolution's first window
// dimension must be reported by the verifier.
TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kConvHloString));

  // Copy the root's window, corrupt the copy, and write it back.
  HloInstruction* root = module->entry_computation()->root_instruction();
  Window corrupted_window = root->window();
  corrupted_window.mutable_dimensions(0)->set_base_dilation(-1);
  root->set_window(corrupted_window);

  const auto verify_status = verifier().Run(module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("non-positive base area dilation factor"));
}
446 
// HLO text for an add whose two operands carry different layouts ({1,0} and
// {0,1}). Used by both the layout-insensitive and layout-sensitive tests.
constexpr const char* kAddWithLayoutChangeHlo = R"(
   HloModule AddWithLayoutChange
    ENTRY AddWithLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[3,4]{0,1} parameter(1)
      ROOT add0 = f32[3,4]{1,0} add(par0,par1)
    }
  )";
455 
// With a layout-insensitive verifier, an add over operands with different
// layouts verifies successfully.
TEST_F(HloVerifierTest, AddWithLayoutChange) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnVerifiedModule(kAddWithLayoutChangeHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
462 
// A dynamic-slice taking one scalar start index per dimension verifies
// successfully when xla_allow_scalar_index_dynamic_ops is enabled.
TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) {
  constexpr char kHloText[] = R"(
    HloModule DynamicSlice_module

    ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
      %original_parameter = s32[2,2,258] parameter(0)
      %constant = s32[] constant(0)
      %start_index = s32[] parameter(1)
      ROOT %dynamic-slice = s32[2,2,258] dynamic-slice(s32[2,2,258] %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
    }
  )";

  // Default module config with the scalar-index debug option switched on.
  HloModuleConfig module_config;
  DebugOptions options = module_config.debug_options();
  options.set_xla_allow_scalar_index_dynamic_ops(true);
  module_config.set_debug_options(options);

  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnVerifiedModule(kHloText, module_config));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
485 
// A dynamic-update-slice taking one scalar start index per dimension verifies
// successfully when xla_allow_scalar_index_dynamic_ops is enabled.
TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) {
  constexpr char kHloText[] = R"(
    HloModule DynamicUpdateSlice_module

    ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
      %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
      %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
      %start_index.0 = s32[] parameter(2)
      %start_index.1 = s32[] parameter(3)
      %start_index.2 = s32[] parameter(4)
      %start_index.3 = s32[] parameter(5)
      ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3)
    }
  )";

  // Default module config with the scalar-index debug option switched on.
  HloModuleConfig module_config;
  DebugOptions options = module_config.debug_options();
  options.set_xla_allow_scalar_index_dynamic_ops(true);
  module_config.set_debug_options(options);

  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnVerifiedModule(kHloText, module_config));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
511 
// Even with mixed precision allowed, a dynamic-update-slice whose result
// element type (bf16) differs from its input operand's (f32) must be rejected.
TEST_F(HloVerifierTestAllowMixedPrecision, DynamicUpdateSliceMixedPrecision) {
  constexpr char kHloText[] = R"(
    HloModule kDynamicUpdateSliceMixedPrecision

    ENTRY %entry (parameter.0: f32[32,511,2048], parameter.1: bf16[32,511,512], parameter.2: s32[], parameter.3: s32[], parameter.4: s32[]) -> bf16[32,511,2048] {
      %parameter.0 = f32[32,511,2048] parameter(0)
      %parameter.1 = bf16[32,511,512] parameter(1)
      %parameter.2 = s32[] parameter(2)
      %parameter.3 = s32[] parameter(3)
      %parameter.4 = s32[] parameter(4)
      ROOT %dus = bf16[32,511,2048] dynamic-update-slice(f32[32,511,2048] %parameter.0, bf16[32,511,512] %parameter.1, s32[] %parameter.2, s32[] %parameter.3, s32[] %parameter.4)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "f32[32,511,2048], actual shape is bf16[32,511,2048]"));
}
533 
// With a layout-sensitive verifier, an add over operands with different
// layouts must be rejected.
TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      ParseAndReturnUnverifiedModule(kAddWithLayoutChangeHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
542 
// With a layout-sensitive verifier, a dynamic-slice whose result layout
// differs from its operand layout must be rejected.
TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
  constexpr char kHloText[] = R"(
   HloModule SliceWithLayoutChange
    ENTRY SliceWithLayoutChange {
      par0 = f32[4,5]{0,1} parameter(0)
      par1 = s32[] parameter(1)
      par2 = s32[] parameter(2)
      ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1, par2),
        dynamic_slice_sizes={3,4}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
561 
// With a layout-sensitive verifier, a concatenate over operands with
// different layouts must be rejected.
TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) {
  constexpr char kHloText[] = R"(
   HloModule ConcatWithLayoutChange
   ENTRY ConcatWithLayoutChange {
      par0 = f32[3,5]{0,1} parameter(0)
      par1 = f32[3,3]{1,0} parameter(1)
      ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1),
        dimensions={1}
   }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
579 
// A bitcast from f32[2] to f32[3] changes the shape size and must be
// rejected; the message reports both sizes.
TEST_F(HloVerifierTestLayoutSensitive, BitcastNeedsSameNumberOfElements) {
  constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY BitcastNeedsToBeNoOp {
   constant.0 = f32[2] constant({0.0, 0.0})
   ROOT bitcast = f32[3] bitcast(constant.0)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Bitcast cannot have different shape sizes of output "
                        "(12) and operand (8)"));
}
598 
// With mixed precision disallowed, a select between an f32 and a bf16 operand
// must be rejected.
TEST_F(HloVerifierTest, SelectMixedPrecisionNotAllowed) {
  constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY SelectMixedPrecisionNotAllowed {
   p0 = pred[32] parameter(0)
   p1 = f32[32] parameter(1)
   p2 = bf16[32] parameter(2)
   ROOT select = f32[32] select(p0, p1, p2)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Seen floating point types of different precisions"));
}
618 
// With mixed precision allowed, a select between an f32 and a bf16 operand
// verifies successfully.
TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) {
  constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY SelectMixedPrecisionAllowed {
   p0 = pred[32] parameter(0)
   p1 = f32[32] parameter(1)
   p2 = bf16[32] parameter(2)
   ROOT select = f32[32] select(p0, p1, p2)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
636 
// A select over tuple-shaped operands must be rejected.
TEST_F(HloVerifierTest, SelectTupleNotAllowed) {
  constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY SelectWithTuple {
    p0 = (f32[], f32[]) parameter(0)
    p1 = (f32[], f32[]) parameter(1)
    p2 = pred[] parameter(2)
    ROOT select = (f32[], f32[]) select(p2, p0, p1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected array argument for select"));
}
656 
// A well-formed copy-start / copy-done pair verifies successfully under the
// layout-sensitive verifier.
TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDone) {
  constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY CopyStartAndCopyDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    copy-start = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
    ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_TRUE(verify_status.ok());
}
673 
// Under the layout-sensitive verifier, a copy-start whose first tuple element
// uses layout {0,1} while copy-done expects {1,0} must be rejected.
TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDoneWrongLayout) {
  constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY CopyStartAndCopyDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    copy-start = (f32[2,3]{0,1:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
    ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to"));
}
692 
// A copy-start declared with an array result instead of the expected
// (f32[2,3], f32[2,3], u32[]) tuple must be rejected.
TEST_F(HloVerifierTest, CopyStartAndCopyDoneWrongType) {
  constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY CopyStartAndCopyDone {
    p0 = f32[2,3] parameter(0)
    copy-start = f32[2,3] copy-start(p0)
    ROOT copy-done = f32[2,3] copy-done(copy-start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "(f32[2,3], f32[2,3], u32[])"));
}
712 
// A copy-start consumed by two copy-done instructions must be rejected.
TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) {
  constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY CopyStartAndCopyDone {
    p0 = f32[2,3] parameter(0)
    copy-start = (f32[2,3], f32[2,3], u32[]) copy-start(p0)
    copy-done.1 = f32[2,3] copy-done(copy-start)
    copy-done.2 = f32[2,3] copy-done(copy-start)
    ROOT tuple = (f32[2,3], f32[2,3]) tuple(copy-done.1, copy-done.2)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("copy-start instruction requires one consumer, found 2"));
}
734 
// A copy-done whose operand is a plain tuple rather than a copy-start must be
// rejected.
TEST_F(HloVerifierTest, CopyDoneNoCopyStart) {
  constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY CopyStartAndCopyDone {
    p0 = f32[2,3] parameter(0)
    p1 = u32[] parameter(1)
    tuple = (f32[2,3], f32[2,3], u32[]) tuple(p0, p0, p1)
    ROOT copy-done = f32[2,3] copy-done(tuple)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("The operand of a copy-done instruction needs to be "
                        "copy-start, found tuple"));
}
755 
// An iota declared with an empty-tuple result shape is not accepted.
TEST_F(HloVerifierTest, IotaNonArrayResult) {
  const char* const kHloText = R"(
  HloModule IotaTupleResult

  ENTRY  kernelEntry {
    ROOT iota = () iota(), iota_dimension=24
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(parsed_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("does not support non-array result"));
}
773 
// An iota with iota_dimension=-1 must produce an error mentioning "negative".
TEST_F(HloVerifierTest, IotaNegativeDimension) {
  const char* const kHloText = R"(
  HloModule IotaTupleResult

  ENTRY  kernelEntry {
    ROOT iota = s32[128,1001]{1,0} iota(), iota_dimension=-1
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(parsed_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr("negative"));
}
790 
// An iota producing a pred[] array is rejected; the error reports "got PRED".
TEST_F(HloVerifierTest, IotaPredResultNotAllowed) {
  const char* const kHloText = R"(
  HloModule IotaPredResult

  ENTRY  kernelEntry {
    ROOT iota = pred[128] iota(), iota_dimension=0
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(parsed_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr("got PRED"));
}
807 
// HLO fixture shared by the two MapOperandComputationMismatch tests: the map
// operand is f64[] while the mapped computation takes and returns f32[].
const char* const kMapOperandComputationMismatchHlo = R"(
  HloModule MapOperandComputationMismatch

  Computation {
    param0 = f32[] parameter(0)
    constant = f32[] constant(1)
    ROOT add = f32[] add(param0, constant)
  }

  ENTRY kernelEntry {
  param = f64[] parameter(0)
  ROOT map = f32[] map(param), dimensions={}, to_apply=Computation
})";
821 
// With the default fixture, an f64[] operand mapped through an f32[]
// computation is reported as a parameter/operand shape mismatch.
TEST_F(HloVerifierTest, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto parsed_module,
      ParseAndReturnUnverifiedModule(kMapOperandComputationMismatchHlo));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Shape mismatch between to_apply computation "
                        "parameter and operand"));
}
832 
// Under the mixed-precision fixture the same module parses as verified and
// passes the verifier.
TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto verified_module,
      ParseAndReturnVerifiedModule(kMapOperandComputationMismatchHlo));
  ASSERT_TRUE(verifier().Run(verified_module.get()).status().ok());
}
839 
// HLO fixture shared by the two ReduceOperandComputationMismatch tests: the
// reduce operands and result are f16 while the reduction computation is f32.
const char* const kReduceOperandComputationMismatchHlo = R"(
  HloModule ReduceOperandComputationMismatch
  computation {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }

  ENTRY kernelEntry {
    arg0 = f16[64,64,224,224]{3,2,1,0} parameter(0)
    constant = f16[] constant(0)
    reduce = f16[64]{0} reduce(arg0, constant), dimensions={0,2,3}, to_apply=computation
  })";
853 
// With the default fixture, an f16 reduce over an f32 computation fails: the
// verifier expects the instruction shape to be f32[64].
TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto parsed_module,
      ParseAndReturnUnverifiedModule(kReduceOperandComputationMismatchHlo));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("Expected instruction to have shape equal to f32[64]"));
}
863 
// Under the mixed-precision fixture the same module parses as verified and
// passes the verifier.
TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto verified_module,
      ParseAndReturnVerifiedModule(kReduceOperandComputationMismatchHlo));
  ASSERT_TRUE(verifier().Run(verified_module.get()).status().ok());
}
871 
ReplicaGroupsStr(std::vector<std::vector<int64>> replica_groups)872 string ReplicaGroupsStr(std::vector<std::vector<int64>> replica_groups) {
873   std::vector<string> replica_group_strs;
874   for (const auto& g : replica_groups) {
875     replica_group_strs.push_back(
876         absl::StrFormat("{%s}", absl::StrJoin(g, ",")));
877   }
878   return absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "));
879 }
880 
ReplicaCount(const std::vector<std::vector<int64>> & replica_groups)881 int64 ReplicaCount(const std::vector<std::vector<int64>>& replica_groups) {
882   int64_t replica_count = 0;
883   for (auto group : replica_groups) {
884     replica_count += group.size();
885   }
886   return replica_count;
887 }
888 
MakeCollectiveCommOpComputation(std::vector<std::vector<int64>> replica_groups,absl::optional<int64> replica_count,absl::optional<int64> num_partitions,absl::string_view other_attributes,absl::string_view template_str)889 StatusOr<std::unique_ptr<HloModule>> MakeCollectiveCommOpComputation(
890     std::vector<std::vector<int64>> replica_groups,
891     absl::optional<int64> replica_count, absl::optional<int64> num_partitions,
892     absl::string_view other_attributes, absl::string_view template_str) {
893   HloModuleConfig config;
894   config.set_replica_count(
895       replica_count.value_or(ReplicaCount(replica_groups)));
896   config.set_num_partitions(num_partitions.value_or(1));
897   return ParseAndReturnUnverifiedModule(
898       absl::StrReplaceAll(
899           template_str,
900           {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)},
901            {"OTHER_ATTRIBUTES", other_attributes.empty()
902                                     ? ""
903                                     : absl::StrCat(",", other_attributes)}}),
904       config);
905 }
906 
MakeAllReduceComputation(std::vector<std::vector<int64>> replica_groups,absl::optional<int64> replica_count=absl::nullopt,absl::optional<int64> num_partitions=absl::nullopt,absl::string_view other_attributes="")907 StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
908     std::vector<std::vector<int64>> replica_groups,
909     absl::optional<int64> replica_count = absl::nullopt,
910     absl::optional<int64> num_partitions = absl::nullopt,
911     absl::string_view other_attributes = "") {
912   const char* kTemplate = R"(
913   HloModule test
914   add {
915     x = f32[] parameter(0)
916     y = f32[] parameter(1)
917     ROOT add = f32[] add(x, y)
918   }
919   ENTRY entry {
920     p = f32[128]{0} parameter(0)
921     crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS
922                                      OTHER_ATTRIBUTES
923   })";
924   return MakeCollectiveCommOpComputation(replica_groups, replica_count,
925                                          num_partitions, other_attributes,
926                                          kTemplate);
927 }
928 
// An all-reduce with no replica groups passes verification.
TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({}));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
933 
// All-reduce replica groups of unequal sizes pass verification.
TEST_F(HloVerifierTest, AllReduce_DifferentGroupSizesOk) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0}, {1, 3}, {2}}));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
939 
// An empty group among the all-reduce replica groups is reported.
TEST_F(HloVerifierTest, AllReduce_EmptyReplicaGroup) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0}, {}}));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("empty replica group"));
}
945 
// Replica 0 is listed in two groups; the verifier must flag the repetition.
TEST_F(HloVerifierTest, AllReduce_RepeatedReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_reduce_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}, {4, 0}}));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica 0 is repeated"));
}
952 
// The groups skip replica 4 (ids 0-3, 5, 6); the verifier must name it.
TEST_F(HloVerifierTest, AllReduce_MissingReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_reduce_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}, {5, 6}}));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica 4 is not named"));
}
959 
// Two replicas are listed while the module is configured with eight.
TEST_F(HloVerifierTest, AllReduce_NotEnougReplicasInGroupConfig) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0, 1}}, 8));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("In kCrossReplica mode, replica groups should contain "
                "8 replicas, but found 2"));
}
966 
// Four replicas are listed while the module is configured with two.
TEST_F(HloVerifierTest, AllReduce_TooManyReplicasInGroupConfig) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_reduce_module,
                          MakeAllReduceComputation({{0, 1}, {2, 3}}, 2));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("In kCrossReplica mode, replica groups should contain "
                "2 replicas, but found 4"));
}
974 
// With channel_id set, four listed ids against a replica count of 2 are
// rejected with the kCrossReplicaAndPartition-mode message.
TEST_F(HloVerifierTest, AllReduce_CrossReplicaAndPartition_Invalid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_reduce_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 2, 1, "channel_id=1"));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("In kCrossReplicaAndPartition mode, replica groups should "
                "contain 2 replicas, but found 4"));
}
985 
// With channel_id set and a replica count of 4, four listed ids verify.
TEST_F(HloVerifierTest, AllReduce_CrossReplicaAndPartition_Valid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_reduce_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 4, 1, "channel_id=1"));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
992 
// With use_global_device_ids, 1 replica x 2 partitions gives the verifier an
// expectation of 2 flattened ids, but four are listed.
TEST_F(HloVerifierTest, AllReduce_FlattenedID_Invalid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_reduce_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 1, 2,
                               "channel_id=1, use_global_device_ids=true"));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("In kFlattenedID mode, replica groups should contain "
                "2 flattened IDs, but found 4"));
}
1002 
// With use_global_device_ids, 2 replicas x 2 partitions and four listed ids
// pass verification.
TEST_F(HloVerifierTest, AllReduce_FlattenedID_Valid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_reduce_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 2, 2,
                               "channel_id=1, use_global_device_ids=true"));
  const auto verify_status = verifier().Run(all_reduce_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
1010 
// A well-formed all-reduce-start / all-reduce-done pair passes verification.
TEST_F(HloVerifierTest, AllReduceStartAndDone) {
  const char* const kHloText = R"(
  HloModule test
  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    start = (f32[2,3], f32[2,3]) all-reduce-start(p0), to_apply=add
    ROOT done = f32[2,3] all-reduce-done(start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  ASSERT_TRUE(verifier().Run(parsed_module.get()).status().ok());
}
1031 
// An all-reduce-start declared with an array result instead of the
// (operand, result) tuple must be rejected with the expected-shape error.
TEST_F(HloVerifierTest, AllReduceStartAndDoneWrongType) {
  const char* const kModuleStr = R"(
  HloModule test
  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    start = f32[2,3] all-reduce-start(p0), to_apply=add
    ROOT done = f32[2,3] all-reduce-done(start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  auto status = verifier().Run(module.get()).status();
  // Fail fast with a clear message if verification unexpectedly succeeds,
  // matching the other negative tests that keep the status in a local.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "(f32[2,3], f32[2,3])"));
}
1054 
// One all-reduce-start consumed by two all-reduce-done instructions must be
// rejected with the "requires one consumer" error.
TEST_F(HloVerifierTest, AllReduceStartAndMultipleDone) {
  const char* const kHloText = R"(
  HloModule test
  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    start = (f32[2,3], f32[2,3]) all-reduce-start(p0), to_apply=add
    done1 = f32[2,3] all-reduce-done(start)
    ROOT done2 = f32[2,3] all-reduce-done(start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(parsed_module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("all-reduce-start instruction requires one "
                        "consumer, found 2"));
}
1079 
// An all-reduce-done whose operand is a plain tuple (not an
// all-reduce-start) must be rejected.
TEST_F(HloVerifierTest, AllReduceDoneWithoutStart) {
  // The tuple's declared shape matches its operands, so the only thing wrong
  // with this module is the opcode feeding all-reduce-done. (The fixture
  // previously declared a 2-element tuple shape over four operands, which is
  // a second, unrelated defect in the HLO.)
  const char* const kModuleStr = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    tuple = (f32[2,3], f32[2,3]) tuple(p0, p0)
    ROOT done = f32[2,3] all-reduce-done(tuple)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("The operand of a all-reduce-done instruction "
                        "needs to be all-reduce-start, found tuple"));
}
1099 
MakeAllToAllComputation(std::vector<std::vector<int64>> replica_groups,absl::optional<int64> replica_count=absl::nullopt,absl::optional<int64> num_partitions=absl::nullopt,absl::string_view other_attributes="")1100 StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
1101     std::vector<std::vector<int64>> replica_groups,
1102     absl::optional<int64> replica_count = absl::nullopt,
1103     absl::optional<int64> num_partitions = absl::nullopt,
1104     absl::string_view other_attributes = "") {
1105   const char* kTemplate = R"(
1106   HloModule test
1107   ENTRY entry {
1108     p0 = f32[128]{0} parameter(0)
1109     p1 = f32[128]{0} parameter(1)
1110     a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS
1111                                                    OTHER_ATTRIBUTES
1112   })";
1113   return MakeCollectiveCommOpComputation(replica_groups, replica_count,
1114                                          num_partitions, other_attributes,
1115                                          kTemplate);
1116 }
1117 
// An all-to-all with no replica groups passes verification.
TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_to_all_module,
                          MakeAllToAllComputation({}));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
1122 
// An empty group among the all-to-all replica groups is reported.
TEST_F(HloVerifierTest, AllToAll_EmptyReplicaGroup) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_to_all_module,
                          MakeAllToAllComputation({{0, 1}, {}}));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("cannot have an empty replica group"));
}
1128 
// Replica 0 is listed in two groups; the verifier must flag the repetition.
TEST_F(HloVerifierTest, AllToAll_RepeatedReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_to_all_module,
      MakeAllToAllComputation({{0, 1}, {2, 3}, {4, 0}}));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica 0 is repeated"));
}
1135 
// The groups skip replica 4 (ids 0-3, 5, 6); the verifier must name it.
TEST_F(HloVerifierTest, AllToAll_MissingReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_to_all_module,
      MakeAllToAllComputation({{0, 1}, {2, 3}, {5, 6}}));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica 4 is not named"));
}
1142 
// All-to-all replica groups of sizes 2, 1 and 2 are rejected as non-uniform.
TEST_F(HloVerifierTest, AllToAll_UniformSizeOfReplicasInGroup) {
  TF_ASSERT_OK_AND_ASSIGN(auto all_to_all_module,
                          MakeAllToAllComputation({{0, 1}, {2}, {3, 4}}));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica groups expected to be of uniform size"));
}
1149 
// With channel_id set, four listed ids against 2 partitions are rejected
// with the kCrossPartition-mode message.
TEST_F(HloVerifierTest, AllToAll_CrossPartition_Invalid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_to_all_module,
      MakeAllToAllComputation({{0, 1}, {2, 3}}, 1, 2, "channel_id=1"));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("In kCrossPartition mode, replica groups should "
                        "contain 2 partitions, but found 4"));
}
1158 
// With channel_id set and 4 partitions, four listed ids pass verification.
TEST_F(HloVerifierTest, AllToAll_CrossPartition_Valid) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto all_to_all_module,
      MakeAllToAllComputation({{0, 1}, {2, 3}}, 1, 4, "channel_id=1"));
  const auto verify_status = verifier().Run(all_to_all_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
1165 
// The two all-to-all operands have the same dimensions but different layouts
// ({0,1} vs {1,0}); the verifier reports operands with different shapes.
TEST_F(HloVerifierTest, AllToAll_LayoutConstrained) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128,4]{0,1} parameter(0)
    p1 = f32[128,4]{1,0} parameter(1)
    ROOT a2a = (f32[128,4]{0,1}, f32[128,4]{1,0}) all-to-all(p0, p1),
      replica_groups={{0,1}}
  }
  )";
  HloModuleConfig module_config;
  module_config.set_replica_count(2);
  TF_ASSERT_OK_AND_ASSIGN(
      auto parsed_module,
      ParseAndReturnUnverifiedModule(kHloText, module_config));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("HLO all-to-all has operands with different shapes"));
}
1183 
// Source 0 appears in two source-target pairs of a single-operand
// collective-permute; the verifier must flag it.
TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,1}, {0,2}, {1,0}}
  }
  )";
  HloModuleConfig module_config;
  module_config.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto parsed_module,
      ParseAndReturnUnverifiedModule(kHloText, module_config));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Source 0 appears more than once"));
}
1200 
// Target 2 appears in two source-target pairs of a single-operand
// collective-permute; the verifier must flag it.
TEST_F(HloVerifierTest, CollectivePermuteSameTargetTwice) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,2}, {1,2}, {2,0}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Target 2 appears more than once"));
}
1215 
// In the four-operand collective-permute form, source 0 is listed in three
// pairs ({0,1}, {0,3}, {0,2}); the verifier reports more than 2 occurrences.
TEST_F(HloVerifierTest, CollectivePermuteSameSourceTooManyTimes) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    ROOT collective-permute = u32[2,8,128]{2,1,0:T(2,128)} collective-permute(u32[2,8,128] broadcast.0, u32[2,8,128] broadcast.1, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{0,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("Source 0 appears more than 2 times in instruction's "
                "source-target pairs:"));
}
1238 
// In the four-operand collective-permute form, target 3 is listed in three
// pairs ({2,3}, {0,3}, {2,3}); the verifier reports more than 2 occurrences.
TEST_F(HloVerifierTest, CollectivePermuteSameTargetTooManyTimes) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    ROOT collective-permute = u32[2,8,128]{2,1,0:T(2,128)} collective-permute(u32[2,8,128] broadcast.0, u32[2,8,128] broadcast.1, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,3},{1,0}}, slice_sizes={{1,8,128},{1,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("Target 3 appears more than 2 times in instruction's "
                "source-target pairs:"));
}
1261 
// The input operand is a single array (broadcast.0) while the output operand
// is a two-element tuple (tuple.output); the verifier reports unmatching
// input and output buffers.
TEST_F(HloVerifierTest, CollectivePermuteUnmatchingSourceTarget) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    constant.4 = s32[] constant(2)
    tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
    tuple.6 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.5)
    tuple.9 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.6)
    ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Unmatching input buffers and output buffers"));
}
1289 
// The input operand is a two-buffer tuple (tuple.input) but the input-offset
// operand is a single offset tuple (tuple.3); the verifier reports
// unmatching input buffers and input offset.
TEST_F(HloVerifierTest, CollectivePermuteUnmatchingInputAndInputOffset) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, u32[2,8,128]{2,1,0:T(2,128)} broadcast.0)
    tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    constant.4 = s32[] constant(2)
    tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
    tuple.6 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.5)
    tuple.9 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.6)
    ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (s32[],s32[],s32[]) tuple.3, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Unmatching input buffers and input offset."));
}
1318 
// The output operand is a two-buffer tuple (tuple.output) but the
// output-offset operand is a single offset tuple (tuple.2); the verifier
// reports unmatching output buffers and output offset.
TEST_F(HloVerifierTest, CollectivePermuteUnmatchingOutputAndOutputOffset) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
      broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
      constant.1 = u32[] constant(1000)
      broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
      broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
      constant.2 = s32[] constant(0)
      constant.3 = s32[] constant(1)
      tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, u32[2,8,128]{2,1,0:T(2,128)} broadcast.0)
      tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
      tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
      tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
      tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
      constant.4 = s32[] constant(2)
      tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
      tuple.7 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.2)
      tuple.8 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.7)
      ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.8, (s32[],s32[],s32[]) tuple.2), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Unmatching output buffers and output offset."));
}
1347 
// With a replica count of 3 and no channel_id, source id 5 is out of range;
// the error names the source and the "< 3" bound.
TEST_F(HloVerifierTest, CollectivePermuteCrossReplicaSourceOOR) {
  const char* const kHloText = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{5,2}, {1,2}, {2,0}}
  }
  )";
  HloModuleConfig module_config;
  module_config.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto parsed_module,
      ParseAndReturnUnverifiedModule(kHloText, module_config));
  const auto verify_status = verifier().Run(parsed_module.get()).status();
  EXPECT_THAT(verify_status.error_message(), HasSubstr("Source 5"));
  EXPECT_THAT(verify_status.error_message(), HasSubstr("must be < 3"));
}
1366 
// A collective-permute without a channel_id lists target 7 in its
// source_target_pairs while the module config declares only 3 replicas. The
// verifier error must name the offending target and the bound.
TEST_F(HloVerifierTest, CollectivePermuteCrossReplicaTargetOOR) {
  static constexpr char kHloText[] = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,1}, {1,2}, {2,7}}
  }
  )";
  HloModuleConfig three_replicas;
  three_replicas.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module,
      ParseAndReturnUnverifiedModule(kHloText, three_replicas));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  EXPECT_THAT(verify_status.error_message(), HasSubstr("Target 7"));
  EXPECT_THAT(verify_status.error_message(), HasSubstr("must be < 3"));
}
1385 
// A collective-permute carrying channel_id=1 lists source 5 in its
// source_target_pairs while the module config declares only 3 partitions. The
// verifier error must name the offending source and the bound.
TEST_F(HloVerifierTest, CollectivePermuteCrossPartitionSourceOOR) {
  static constexpr char kHloText[] = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{5,2}, {1,2}, {2,0}}, channel_id=1
  }
  )";
  HloModuleConfig three_partitions;
  three_partitions.set_num_partitions(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module,
      ParseAndReturnUnverifiedModule(kHloText, three_partitions));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  EXPECT_THAT(verify_status.error_message(), HasSubstr("Source 5"));
  EXPECT_THAT(verify_status.error_message(), HasSubstr("must be < 3"));
}
1404 
// A collective-permute carrying channel_id=1 lists target 7 in its
// source_target_pairs while the module config declares only 3 partitions. The
// verifier error must name the offending target and the bound.
TEST_F(HloVerifierTest, CollectivePermuteCrossPartitionTargetOOR) {
  static constexpr char kHloText[] = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,2}, {1,7}, {2,0}}, channel_id=1
  }
  )";
  HloModuleConfig three_partitions;
  three_partitions.set_num_partitions(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module,
      ParseAndReturnUnverifiedModule(kHloText, three_partitions));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  EXPECT_THAT(verify_status.error_message(), HasSubstr("Target 7"));
  EXPECT_THAT(verify_status.error_message(), HasSubstr("must be < 3"));
}
1423 
// The fusion instruction is declared as f32[10] while the root of its fused
// computation is f32[10,10]. The verifier must report the mismatch with a
// "Fused computation shape" error.
TEST_F(HloVerifierTest, FusionShapeVerifier) {
  static constexpr char kHloText[] = R"(
  HloModule test

  fused_computation {
    ROOT p0 = f32[10,10] parameter(0)
  }

  ENTRY entry {
    p0 = f32[10,10] parameter(0)
    ROOT out = f32[10] fusion(p0), kind=kInput, calls=fused_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Fused computation shape"));
}
1442 
// The module holds two all-reduce instructions, only one of which sets
// constrain_layout=true. The verifier must reject the mix.
TEST_F(HloVerifierTest, AllReduceVerifier) {
  static constexpr char kHloText[] = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY entry {
    input = f32[8,12]{0,1} parameter(0)
    crs0 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add
    crs1 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add,
      constrain_layout=true
    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(crs0, crs1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("mix of layout constrained and unconstrained AllReduce"));
}
1467 
// A send/send-done pair and an all-reduce all use channel_id=1. The verifier
// must flag a channel id shared by different kinds of channel instructions.
TEST_F(HloVerifierTest, ChannelVerifier) {
  static constexpr char kHloText[] = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY entry {
    %input = f32[8,12] parameter(0)
    %token0 = token[] after-all()
    %send = (f32[8,12], u32[], token[]) send(%input, %token0), channel_id=1
    %send-done = token[] send-done(%send), channel_id=1
    %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
      channel_id=1
    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%input, %crs)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("used for different types of channel instructions"));
}
1493 
// A collective-permute and an all-reduce both use channel_id=1. The verifier
// must flag a channel id shared by different kinds of channel instructions.
TEST_F(HloVerifierTest, CollectiveChannelVerifier) {
  static constexpr char kHloText[] = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY entry {
    %input = f32[8,12] parameter(0)
    %permute = f32[8,12] collective-permute(%input),
      source_target_pairs={{0,1},{1,0}}, channel_id=1
    %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
      channel_id=1
    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%permute, %crs)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));

  const auto verify_status = verifier().Run(hlo_module.get()).status();
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("used for different types of channel instructions"));
}
1518 
// Positive case, run with the HloVerifierTestLayoutSensitive fixture: a
// collective-permute-start producing the 4-tuple
// (operand, result, u32[], u32[]) and consumed by exactly one
// collective-permute-done must pass verification.
TEST_F(HloVerifierTestLayoutSensitive, CollectivePermuteStartAndDone) {
  const char* const kModuleStr = R"(
  HloModule Module

  ENTRY CollectivePermuteStartAndDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
    ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  // TF_ASSERT_OK (rather than ASSERT_TRUE(status.ok())) so that a failure
  // prints the verifier's error message instead of a bare "false".
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
1535 
// The collective-permute-start is declared with a plain array shape instead
// of a tuple. The verifier must fail and quote the expected shape
// (f32[2,3], f32[2,3], u32[], u32[]).
TEST_F(HloVerifierTest, CollectivePermuteStartAndDoneWrongType) {
  static constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY CollectivePermuteStartAndDoneWrongType {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    collective-permute-start.1 = f32[2,3]{1,0:S(1)} collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
    ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
  }
  )";
  static constexpr char kExpectedError[] =
      "Expected instruction to have shape equal to "
      "(f32[2,3], f32[2,3], u32[], u32[])";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(hlo_module.get()).status();

  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr(kExpectedError));
}
1555 
// One collective-permute-start feeds two collective-permute-done
// instructions. The verifier must fail and report that it found 2 consumers
// where one is required.
TEST_F(HloVerifierTest, CollectivePermuteStartAndMultipleDone) {
  static constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY CollectivePermuteStartAndMultipleDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
    collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
    ROOT collective-permute-done.2 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
  }
  )";
  static constexpr char kExpectedError[] =
      "collective-permute-start instruction requires one consumer, "
      "found 2";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(hlo_module.get()).status();

  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr(kExpectedError));
}
1577 
// The collective-permute-done takes an ordinary tuple as its operand rather
// than a collective-permute-start. The verifier must fail and report that it
// found a tuple.
TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) {
  static constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY CollectivePermuteDoneNoCollectivePermuteStart {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    p1 = f32[2,3]{1,0:S(1)} parameter(1)
    p2 = u32[] parameter(2)
    p3 = u32[] parameter(3)
    tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2, p3)
    ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1)
  }
  )";
  static constexpr char kExpectedError[] =
      "The operand of a collective-permute-done instruction "
      "needs to be collective-permute-start, found tuple";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(hlo_module.get()).status();

  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr(kExpectedError));
}
1600 
// A compare of f32 operands annotated with type=UNSIGNED must be rejected.
// The error names FLOAT or TOTALORDER as the expected comparison type.
TEST_F(HloVerifierTest, ComparisonTypeFloat) {
  static constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY RngOperandElementTypesNotMatch {
   p0 = f32[] parameter(0)
   ROOT cmp = pred[] compare(f32[] p0, f32[] p0), direction=LT, type=UNSIGNED
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(hlo_module.get()).status();

  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected comparison type FLOAT or TOTALORDER"));
}
1618 
// A compare of s32 operands annotated with type=UNSIGNED must be rejected.
// The error names SIGNED as the expected comparison type.
TEST_F(HloVerifierTest, ComparisonTypeSigned) {
  static constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY RngOperandElementTypesNotMatch {
   p0 = s32[] parameter(0)
   ROOT cmp = pred[] compare(s32[] p0, s32[] p0), direction=LT, type=UNSIGNED
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(hlo_module.get()).status();

  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected comparison type SIGNED"));
}
1636 
// A compare of u32 operands annotated with type=SIGNED must be rejected.
// The error names UNSIGNED as the expected comparison type.
TEST_F(HloVerifierTest, ComparisonTypeUnsigned) {
  static constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY RngOperandElementTypesNotMatch {
   p0 = u32[] parameter(0)
   ROOT cmp = pred[] compare(u32[] p0, u32[] p0), direction=LT, type=SIGNED
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(hlo_module.get()).status();

  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected comparison type UNSIGNED"));
}
1654 
// A compare of pred operands annotated with type=SIGNED must be rejected.
// The error names UNSIGNED as the expected comparison type.
TEST_F(HloVerifierTest, ComparisonTypePred) {
  static constexpr char kHloText[] = R"(
  HloModule Module

  ENTRY RngOperandElementTypesNotMatch {
   p0 = pred[] parameter(0)
   ROOT cmp = pred[] compare(pred[] p0, pred[] p0), direction=LT, type=SIGNED
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(hlo_module.get()).status();

  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected comparison type UNSIGNED"));
}
1672 
// An all-reduce with channel_id=1 and use_global_device_ids=true but an empty
// replica_groups list must be rejected. The error states that replica groups
// must be specified in flattened-id mode.
TEST_F(HloVerifierTest, UseGlobalDeviceIdsEmptyReplicaGroup) {
  static constexpr char kHloText[] = R"(
  HloModule Module
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, channel_id=1,
                         use_global_device_ids=true, to_apply=add
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(hlo_module.get()).status();

  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("Replica groups must be specified in flattened-id mode"));
}
1696 
// An all-reduce that sets use_global_device_ids=true without any channel_id
// must be rejected as an invalid has_channel_id / use_global_device_ids
// combination.
TEST_F(HloVerifierTest, InvalidChannelIDandUseGlobalDeviceIDs) {
  static constexpr char kHloText[] = R"(
  HloModule Module
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    ROOT crs = f32[8]{0} all-reduce(input), replica_groups={},
                         use_global_device_ids=true, to_apply=add
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(hlo_module.get()).status();

  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr(
          "Invalid combination of has_channel_id and use_global_device_ids"));
}
1721 
// A reduce-scatter over the two-member replica group {0,1} whose output
// f32[8] has the same size as its input must be rejected. The error reports
// "shard_count = 1, subgroup_size = 2".
TEST_F(HloVerifierTest, ReduceScatterInvalidOutputSize0) {
  static constexpr char kHloText[] = R"(
  HloModule Module
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    ROOT crs = f32[8]{0} reduce-scatter(input), replica_groups={{0,1}},
                         to_apply=add, dimensions={0}
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(hlo_module.get()).status();

  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("shard_count = 1, subgroup_size = 2"));
}
1744 
// A reduce-scatter on a rank-1 operand that names dimension 1 as its scatter
// dimension must be rejected. The error carries the text of the failed
// scatter_dimension-vs-rank check.
TEST_F(HloVerifierTest, ReduceScatterInvalidScatterDim) {
  static constexpr char kHloText[] = R"(
  HloModule Module
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    ROOT crs = f32[4]{0} reduce-scatter(input), replica_groups={{0,1}},
                         to_apply=add, dimensions={1}
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(hlo_module.get()).status();

  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("ars->scatter_dimension() < ars->operand(i)->shape().rank()"));
}
1768 
// A reduce-scatter whose replica groups differ in size ({0,1} vs {2,3,4})
// must be rejected. The error states that groups must be of uniform size.
TEST_F(HloVerifierTest, ReduceScatterNonUniformGroups) {
  static constexpr char kHloText[] = R"(
  HloModule Module
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    ROOT crs = f32[4]{0} reduce-scatter(input), replica_groups={{0,1}, {2,3,4}},
                         to_apply=add, dimensions={0}
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const auto verify_status = verifier().Run(hlo_module.get()).status();

  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica groups expected to be of uniform size"));
}
1791 
1792 }  // namespace
1793 }  // namespace xla
1794