• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "absl/strings/str_replace.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/hlo_parser.h"
27 #include "tensorflow/compiler/xla/service/layout_assignment.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/xla.pb.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 
34 namespace xla {
35 namespace {
36 
37 using ::testing::HasSubstr;
38 
CreateUnverifiedModule()39 std::unique_ptr<HloModule> CreateUnverifiedModule() {
40   return std::make_unique<HloModule>("module", HloModuleConfig());
41 }
42 
// The fixtures below build deliberately malformed HLO. Tests therefore create
// modules with the unverified helpers (CreateUnverifiedModule,
// ParseAndReturnUnverifiedModule) and run verifier() on them explicitly.
45 class HloVerifierTest : public HloTestBase {
46  public:
HloVerifierTest()47   HloVerifierTest()
48       : HloTestBase(/*verifier_layout_sensitive=*/false,
49                     /*allow_mixed_precision_in_hlo_verifier=*/false) {}
50 };
51 
52 class HloVerifierTestAllowMixedPrecision : public HloTestBase {
53  public:
HloVerifierTestAllowMixedPrecision()54   HloVerifierTestAllowMixedPrecision()
55       : HloTestBase(/*verifier_layout_sensitive=*/false,
56                     /*allow_mixed_precision_in_hlo_verifier=*/true) {}
57 };
58 
59 class HloVerifierTestLayoutSensitive : public HloTestBase {
60  public:
HloVerifierTestLayoutSensitive()61   HloVerifierTestLayoutSensitive()
62       : HloTestBase(/*verifier_layout_sensitive=*/true,
63                     /*allow_mixed_precision_in_hlo_verifier=*/false,
64                     LayoutAssignment::InstructionCanChangeLayout) {}
65 };
66 
67 class HloVerifierTestLayoutFusion : public HloTestBase {
68  public:
HloVerifierTestLayoutFusion()69   HloVerifierTestLayoutFusion()
70       : HloTestBase(/*verifier_layout_sensitive=*/true,
71                     /*allow_mixed_precision_in_hlo_verifier=*/false) {}
72 };
73 
TEST_F(HloVerifierTest, NullInstructionParent) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  HloInstruction* const p =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* const neg = b.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, p));

  std::unique_ptr<HloModule> m = CreateUnverifiedModule();
  m->AddEntryComputation(b.Build());

  // The freshly built module is valid.
  TF_ASSERT_OK(verifier().Run(m.get()).status());

  // Detach the negate from its computation; the verifier must notice.
  neg->set_parent(nullptr);
  const auto result = verifier().Run(m.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("has a null parent pointer"));
}
92 
TEST_F(HloVerifierTest, NullComputationParent) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  HloComputation::Builder b(TestName());
  HloInstruction* const p =
      b.AddInstruction(HloInstruction::CreateParameter(0, f32_scalar, "param"));
  b.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, p));

  std::unique_ptr<HloModule> m = CreateUnverifiedModule();
  HloComputation* const entry = m->AddEntryComputation(b.Build());

  // The freshly built module is valid.
  TF_ASSERT_OK(verifier().Run(m.get()).status());

  // Detach the computation from its module; the verifier must notice.
  entry->set_parent(nullptr);
  const auto result = verifier().Run(m.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("has a null parent pointer"));
}
111 
TEST_F(HloVerifierTest, DifferentOperandParents) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  std::unique_ptr<HloModule> m = CreateUnverifiedModule();

  // Entry computation: negate(param).
  HloComputation::Builder entry_b(TestName());
  HloInstruction* const entry_p = entry_b.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  HloInstruction* const neg = entry_b.AddInstruction(
      HloInstruction::CreateUnary(f32_scalar, HloOpcode::kNegate, entry_p));
  m->AddEntryComputation(entry_b.Build());

  // A second, embedded computation that owns its own parameter.
  HloComputation::Builder embedded_b(TestName());
  HloInstruction* const embedded_p = embedded_b.AddInstruction(
      HloInstruction::CreateParameter(0, f32_scalar, "param"));
  m->AddEmbeddedComputation(embedded_b.Build());

  TF_ASSERT_OK(verifier().Run(m.get()).status());

  // Rewire the negate to consume an operand from the other computation.
  TF_ASSERT_OK(neg->ReplaceOperandWith(0, embedded_p));

  const auto result = verifier().Run(m.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("is in a different computation"));
}
135 
TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
  const Shape one_elem = ShapeUtil::MakeShape(F32, {1});
  const Shape two_elem = ShapeUtil::MakeShape(F32, {2});

  HloComputation::Builder b(TestName());
  HloInstruction* const p =
      b.AddInstruction(HloInstruction::CreateParameter(0, one_elem, "param"));

  // add(f32[1], f32[1]) is deliberately given the wrong result shape f32[2].
  HloInstruction* const bad_add = b.AddInstruction(
      HloInstruction::CreateBinary(two_elem, HloOpcode::kAdd, p, p));

  // The regression only reproduces when the mis-shaped instruction is not
  // the root, so feed it into a multiply.
  b.AddInstruction(HloInstruction::CreateBinary(two_elem, HloOpcode::kMultiply,
                                                bad_add, bad_add));

  std::unique_ptr<HloModule> m = CreateUnverifiedModule();
  m->AddEntryComputation(b.Build());

  // Every run must fail: the verifier may not carry DFS visitor state from
  // one invocation to the next.
  for (int run = 0; run < 2; ++run) {
    EXPECT_FALSE(verifier().Run(m.get()).status().ok()) << "run " << run;
  }
}
161 
TEST_F(HloVerifierTest, CheckCallOperandParameterShapesMismatch) {
  // The call passes (f32[4], s32[]) to a computation whose parameter is
  // declared as (s32[], f32[4]).
  constexpr char kHlo[] = R"(
  HloModule Module

  callme {
    ROOT param = (s32[], f32[4]) parameter(0)
  }

  ENTRY entry {
    p0 = (f32[4], s32[]) parameter(0)
    ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("shape does not match parameter"));
}
183 
TEST_F(HloVerifierTest, CheckCallThreadMismatch) {
  // The callee runs on "parallel_thread" while the caller (ENTRY) does not
  // declare an execution thread.
  const char* const kHlo = R"(
    HloModule Module

    callme {
      ROOT param = (s32[], f32[4]) parameter(0)
    }, execution_thread="parallel_thread"

    ENTRY entry {
      p0 = (s32[], f32[4]) parameter(0)
      ROOT mycall = (s32[], f32[4]) call(p0), to_apply=callme
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("expects parent computation thread name same as called "
                        "computation's thread name"));
}
205 
TEST_F(HloVerifierTest, CheckConditionalOperandParameterShapesMismatch) {
  // Both branches take (s32[], f32[4]), but the conditional passes the
  // tuple p0 of shape (f32[4], s32[]).
  constexpr char kHlo[] = R"(
  HloModule Module

  true_branch {
    tparam = (s32[], f32[4]) parameter(0)
    ROOT tgte1 = f32[4] get-tuple-element(tparam), index=1
  }

  false_branch {
    fparam = (s32[], f32[4]) parameter(0)
    ROOT fgte1 = f32[4] get-tuple-element(fparam), index=1
  }

  ENTRY entry {
    p0 = (f32[4], s32[]) parameter(0)
    constant = pred[] constant(true)
    ROOT conditional = f32[4] conditional(constant, p0, p0),
      true_computation=true_branch, false_computation=false_branch
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("shape does not match parameter"));
}
235 
// Verifies that an indexed conditional rejects a branch-index operand that is
// not an S32 scalar. The module is first checked as-is, so the failures below
// are known to come from the shape mutations and not from some unrelated
// defect in the HLO text.
TEST_F(HloVerifierTest, CheckConditionalBranchIndexOperandShape) {
  const char* const hlo_string = R"(
  HloModule Module

  branch0 {
    tparam = f32[4] parameter(0)
    ROOT tgte1 = f32[4] ceil(tparam)
  }

  branch1 {
    fparam = f32[4] parameter(0)
    ROOT fgte1 = f32[4] floor(fparam)
  }

  branch2 {
    sparam = f32[4] parameter(0)
    ROOT sgte1 = f32[4] ceil(sparam)
  }

  ENTRY entry {
    p0 = f32[4] parameter(0)
    b0 = s32[] parameter(1)
    ROOT conditional = f32[4] conditional(b0, p0, p0, p0),
      branch_computations={branch0, branch1, branch2}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  // Baseline: with an s32[] branch index the module is well formed. The
  // original code computed this status and then discarded it.
  TF_ASSERT_OK(verifier().Run(module.get()).status());

  HloInstruction* condition = FindInstruction(module.get(), "b0");
  ASSERT_NE(condition, nullptr);

  // A scalar of the wrong element type.
  *condition->mutable_shape() = ShapeUtil::MakeShape(F32, {});
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr(
          "first operand of indexed conditional must be a scalar of S32"));

  // The right element type, but not a scalar.
  *condition->mutable_shape() = ShapeUtil::MakeShape(S32, {4});
  status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("first operand of conditional must be a scalar"));
}
281 
// Verifies that a conditional branch placed on a different execution thread
// ("parallel_thread") than its caller is rejected.
TEST_F(HloVerifierTest, CheckConditionalBranchThread) {
  const char* const hlo_string = R"(
    HloModule Module

    branch0 {
      tparam = f32[4] parameter(0)
      ROOT tgte1 = f32[4] ceil(tparam)
    }

    branch1 {
      fparam = f32[4] parameter(0)
      ROOT fgte1 = f32[4] floor(fparam)
    }, execution_thread="parallel_thread"

    branch2 {
      sparam = f32[4] parameter(0)
      ROOT sgte1 = f32[4] ceil(sparam)
    }

    ENTRY entry {
      p0 = f32[4] parameter(0)
      b0 = s32[] parameter(1)
      ROOT conditional = f32[4] conditional(b0, p0, p0, p0),
        branch_computations={branch0, branch1, branch2}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));
  auto status = verifier().Run(module.get()).status();
  // Assert failure explicitly, as the other negative tests do, so an
  // unexpected OK is reported as such instead of as a substring mismatch
  // against an empty message.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("expects parent computation thread name same as called "
                        "computation's thread name"));
}
315 
TEST_F(HloVerifierTest, CheckConditionalBranchContainsAsyncThread) {
  // branch1 stays on the caller's thread and only launches async work on
  // "parallel_thread"; the module is expected to verify cleanly.
  constexpr char kHlo[] = R"(
    HloModule Module

    branch0 {
      tparam = f32[4] parameter(0)
      ROOT tgte1 = f32[4] ceil(tparam)
    }

    branch1 {
      fparam = f32[4] parameter(0)
      %async-start = ((f32[4]), f32[4], s32[]) custom-call-start(f32[4] fparam), async_execution_thread="parallel_thread", custom_call_target="foo"
      ROOT %async-done = f32[4] custom-call-done(((f32[4]), f32[4], s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="foo"
    }

    branch2 {
      sparam = f32[4] parameter(0)
      ROOT sgte1 = f32[4] ceil(sparam)
    }

    ENTRY entry {
      p0 = f32[4] parameter(0)
      b0 = s32[] parameter(1)
      ROOT conditional = f32[4] conditional(b0, p0, p0, p0),
        branch_computations={branch0, branch1, branch2}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  const auto result = verifier().Run(module.get()).status();
  TF_ASSERT_OK(result);
}
347 
TEST_F(HloVerifierTest, RngOpnd0NotScalar) {
  // The second rng operand is f16[2] rather than a scalar.
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY RngOpnd0NotScalar {
   constant.0 = f32[] constant(0)
   constant.1 = f16[2] constant({1, 3})
   ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[2] constant.1),
    distribution=rng_uniform
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("Expected scalar type"));
}
366 
TEST_F(HloVerifierTest, RngOperandElementTypesDoNotMatch) {
  // The rng operands disagree on element type: f32 vs. f16.
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY RngOperandElementTypesNotMatch {
   constant.0 = f32[] constant(0)
   constant.1 = f16[] constant(1)
   ROOT rng.0 = f32[10]{0} rng(f32[] constant.0, f16[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected compatible element types"));
}
386 
TEST_F(HloVerifierTest, RngMixedPrecisionNotAllowed) {
  // f32 operands produce an f16 result; this fixture forbids mixed precision.
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY RngResultElementTypeNotMatch {
   constant.0 = f32[] constant(0)
   constant.1 = f32[] constant(1)
   ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected compatible element types"));
}
406 
TEST_F(HloVerifierTestAllowMixedPrecision, RngMixedPrecisionAllowed) {
  // The same f32 -> f16 rng as above verifies when mixed precision is allowed.
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY RngResultElementTypeNotMatch {
   constant.0 = f32[] constant(0)
   constant.1 = f32[] constant(1)
   ROOT rng.0 = f16[10]{0} rng(f32[] constant.0, f32[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
424 
TEST_F(HloVerifierTest, RngElementTypeNotSupported) {
  // A normal-distribution rng over s32 values.
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY RngElementTypeNotSupported {
   constant.0 = s32[] constant(0)
   constant.1 = s32[] constant(1)
   ROOT rng.0 = s32[10]{0} rng(s32[] constant.0, s32[] constant.1),
    distribution=rng_normal
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(), HasSubstr("Element type not supported"));
}
443 
TEST_F(HloVerifierTest, NegativeInteriorPaddingNotAllowed) {
  // Built programmatically because the HLO text parser does not accept
  // negative interior padding.
  const Shape vec100 = ShapeUtil::MakeShape(F32, {100});
  PaddingConfig cfg;
  cfg.add_dimensions()->set_interior_padding(-1);

  HloComputation::Builder b(TestName());
  HloInstruction* const operand =
      b.AddInstruction(HloInstruction::CreateParameter(0, vec100, "param"));
  HloInstruction* const pad_value = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
  b.AddInstruction(HloInstruction::CreatePad(vec100, operand, pad_value, cfg));

  std::unique_ptr<HloModule> m = CreateUnverifiedModule();
  m->AddEntryComputation(b.Build());

  const auto result = verifier().Run(m.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
467 
// Verifies that a pad with negative interior padding (the pad's interior
// dilation) is rejected. Built programmatically because the HLO text parser
// does not accept negative interior padding.
TEST_F(HloVerifierTest, PadNegativeInteriorDilationNotAllowed) {
  HloComputation::Builder builder(TestName());
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeShape(F32, {100}), "param"));
  PaddingConfig padding_config;
  padding_config.add_dimensions()->set_interior_padding(-1);
  // LiteralUtil::Zero already returns a fresh Literal, so no Clone() is
  // needed before handing it to CreateConstant.
  builder.AddInstruction(HloInstruction::CreatePad(
      ShapeUtil::MakeShape(F32, {100}), param,
      builder.AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::Zero(F32))),
      padding_config));

  auto module = CreateUnverifiedModule();
  module->AddEntryComputation(builder.Build());

  // Assert failure first, consistent with the other negative tests.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Interior padding cannot be negative"));
}
489 
TEST_F(HloVerifierTest, DotMixedPrecisionAllowed) {
  // f32 x bf16 dot with an f32 result; expected to verify even under the
  // default fixture.
  constexpr char kDotHlo[] = R"(
HloModule module
ENTRY entry_computation {
  a = f32[2,10] parameter(0)
  b = bf16[10,2] parameter(1)
  ROOT dot = f32[2,2] dot(a, b), lhs_contracting_dims={1}, rhs_contracting_dims={0}
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kDotHlo));
  TF_EXPECT_OK(verifier().Run(module.get()).status());
}
504 
// Module whose root is a 3x3, stride-2 convolution over f16 operands.
constexpr char kConvHloString[] = R"(
HloModule module
ENTRY entry_computation {
  param0 = f16[128,128,56,56] parameter(0)
  param1 = f16[3,3,128,128] parameter(1)
  zero_f16 = f16[] constant(0)
  ROOT conv = f16[128,128,28,28] convolution(param0, param1),
    window={size=3x3 stride=2x2}, dim_labels=bf01_01io->bf01
})";
515 
// Verifies that a convolution with a negative window dilation factor is
// rejected.
TEST_F(HloVerifierTest, ConvNegativeWindowDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kConvHloString));
  // Corrupt the root convolution's window: window dilation of -1 on dim 0.
  auto* conv = module->entry_computation()->root_instruction();
  Window w = conv->window();
  w.mutable_dimensions(0)->set_window_dilation(-1);
  conv->set_window(w);

  // Assert failure first so an unexpected OK is reported clearly instead of
  // as a substring mismatch against an empty message.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("non-positive window dilation factor"));
}
527 
// Verifies that a convolution with a negative base (lhs) dilation factor is
// rejected.
TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kConvHloString));
  // Corrupt the root convolution's window: base dilation of -1 on dim 0.
  auto* conv = module->entry_computation()->root_instruction();
  Window w = conv->window();
  w.mutable_dimensions(0)->set_base_dilation(-1);
  conv->set_window(w);

  // Assert failure first so an unexpected OK is reported clearly instead of
  // as a substring mismatch against an empty message.
  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("non-positive base area dilation factor"));
}
539 
// An add whose second operand has a {0,1} layout while the result is {1,0}.
constexpr char kAddWithLayoutChangeHlo[] = R"(
   HloModule AddWithLayoutChange
    ENTRY AddWithLayoutChange {
      par0 = f32[3,4]{1,0} parameter(0)
      par1 = f32[3,4]{0,1} parameter(1)
      ROOT add0 = f32[3,4]{1,0} add(par0,par1)
    }
  )";
548 
TEST_F(HloVerifierTest, AddWithLayoutChange) {
  // The default fixture is layout-insensitive, so the layout change is fine.
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnVerifiedModule(kAddWithLayoutChangeHlo));
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
555 
TEST_F(HloVerifierTest, ScalarIndexDynamicSlice) {
  // dynamic-slice with one scalar start index per dimension.
  constexpr char kHlo[] = R"(
    HloModule DynamicSlice_module

    ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
      %original_parameter = s32[2,2,258] parameter(0)
      %constant = s32[] constant(0)
      %start_index = s32[] parameter(1)
      ROOT %dynamic-slice = s32[2,2,258] dynamic-slice(s32[2,2,258] %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
    }
  )";

  HloModuleConfig config;
  {
    DebugOptions opts = config.debug_options();
    opts.set_xla_allow_scalar_index_dynamic_ops(true);
    config.set_debug_options(opts);
  }

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kHlo, config));
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
578 
TEST_F(HloVerifierTest, ScalarIndexDynamicUpdateSlice) {
  // dynamic-update-slice with one scalar start index per dimension.
  constexpr char kScalarIndexDynamicUpdateSlice[] = R"(
    HloModule DynamicUpdateSlice_module

    ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
      %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
      %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
      %start_index.0 = s32[] parameter(2)
      %start_index.1 = s32[] parameter(3)
      %start_index.2 = s32[] parameter(4)
      %start_index.3 = s32[] parameter(5)
      ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3)
    }
  )";

  HloModuleConfig config;
  {
    DebugOptions opts = config.debug_options();
    opts.set_xla_allow_scalar_index_dynamic_ops(true);
    config.set_debug_options(opts);
  }

  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      ParseAndReturnVerifiedModule(kScalarIndexDynamicUpdateSlice, config));
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
604 
TEST_F(HloVerifierTestAllowMixedPrecision, DynamicUpdateSliceMixedPrecision) {
  // Even with mixed precision allowed, the dynamic-update-slice result must
  // match the f32 operand's shape; here it is bf16.
  constexpr char kHlo[] = R"(
    HloModule kDynamicUpdateSliceMixedPrecision

    ENTRY %entry (parameter.0: f32[32,511,2048], parameter.1: bf16[32,511,512], parameter.2: s32[], parameter.3: s32[], parameter.4: s32[]) -> bf16[32,511,2048] {
      %parameter.0 = f32[32,511,2048] parameter(0)
      %parameter.1 = bf16[32,511,512] parameter(1)
      %parameter.2 = s32[] parameter(2)
      %parameter.3 = s32[] parameter(3)
      %parameter.4 = s32[] parameter(4)
      ROOT %dus = bf16[32,511,2048] dynamic-update-slice(f32[32,511,2048] %parameter.0, bf16[32,511,512] %parameter.1, s32[] %parameter.2, s32[] %parameter.3, s32[] %parameter.4)
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "f32[32,511,2048], actual shape is bf16[32,511,2048]"));
}
626 
TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) {
  // The same add that passes the layout-insensitive fixture must be rejected
  // here.
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnUnverifiedModule(kAddWithLayoutChangeHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
635 
TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
  // dynamic-slice reads a {0,1} operand but produces a {1,0} result.
  constexpr char kHlo[] = R"(
   HloModule SliceWithLayoutChange
    ENTRY SliceWithLayoutChange {
      par0 = f32[4,5]{0,1} parameter(0)
      par1 = s32[] parameter(1)
      par2 = s32[] parameter(2)
      ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1, par2),
        dynamic_slice_sizes={3,4}
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
654 
TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) {
  // One concatenate operand is {0,1} while the result is {1,0}.
  constexpr char kHlo[] = R"(
   HloModule ConcatWithLayoutChange
   ENTRY ConcatWithLayoutChange {
      par0 = f32[3,5]{0,1} parameter(0)
      par1 = f32[3,3]{1,0} parameter(1)
      ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1),
        dimensions={1}
   }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Instruction shouldn't change layouts"));
}
672 
TEST_F(HloVerifierTestLayoutSensitive, BitcastNeedsSameNumberOfElements) {
  // f32[2] (8 bytes) bitcast to f32[3] (12 bytes).
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY BitcastNeedsToBeNoOp {
   constant.0 = f32[2] constant({0.0, 0.0})
   ROOT bitcast = f32[3] bitcast(constant.0)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Bitcast cannot have different shape sizes of output "
                        "(12) and operand (8)"));
}
691 
TEST_F(HloVerifierTest, SelectMixedPrecisionNotAllowed) {
  // select between an f32 and a bf16 array; this fixture forbids mixing.
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY SelectMixedPrecisionNotAllowed {
   p0 = pred[32] parameter(0)
   p1 = f32[32] parameter(1)
   p2 = bf16[32] parameter(2)
   ROOT select = f32[32] select(p0, p1, p2)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Seen floating point types of different precisions"));
}
711 
TEST_F(HloVerifierTestAllowMixedPrecision, SelectMixedPrecisionAllowed) {
  // The same f32/bf16 select verifies once mixed precision is allowed.
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY SelectMixedPrecisionAllowed {
   p0 = pred[32] parameter(0)
   p1 = f32[32] parameter(1)
   p2 = bf16[32] parameter(2)
   ROOT select = f32[32] select(p0, p1, p2)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
729 
TEST_F(HloVerifierTest, SelectTupleNotAllowed) {
  // select over tuple-shaped operands.
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY SelectWithTuple {
    p0 = (f32[], f32[]) parameter(0)
    p1 = (f32[], f32[]) parameter(1)
    p2 = pred[] parameter(2)
    ROOT select = (f32[], f32[]) select(p2, p0, p1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto result = verifier().Run(module.get()).status();
  ASSERT_FALSE(result.ok());
  EXPECT_THAT(result.error_message(),
              HasSubstr("Expected array argument for select"));
}
749 
TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDone) {
  // A well-formed copy-start/copy-done pair moving data from memory space
  // S(1) to S(2) with a consistent {1,0} layout.
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY CopyStartAndCopyDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    copy-start = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
    ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
766 
// copy-start's destination element uses layout {0,1} while copy-done produces
// {1,0}; the layout-sensitive verifier must flag the shape mismatch.
TEST_F(HloVerifierTestLayoutSensitive, CopyStartAndCopyDoneWrongLayout) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY CopyStartAndCopyDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    copy-start = (f32[2,3]{0,1:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(p0)
    ROOT copy-done = f32[2,3]{1,0:S(2)} copy-done(copy-start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to"));
}
785 
// copy-start declared as a plain array instead of the expected
// (dest, src, context) tuple must be rejected.
TEST_F(HloVerifierTest, CopyStartAndCopyDoneWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY CopyStartAndCopyDone {
    p0 = f32[2,3] parameter(0)
    copy-start = f32[2,3] copy-start(p0)
    ROOT copy-done = f32[2,3] copy-done(copy-start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "(f32[2,3], f32[2,3], u32[])"));
}
805 
// A copy-start consumed by two copy-done instructions must be rejected; it
// needs exactly one consumer.
TEST_F(HloVerifierTest, CopyStartMultipleCopyDone) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY CopyStartAndCopyDone {
    p0 = f32[2,3] parameter(0)
    copy-start = (f32[2,3], f32[2,3], u32[]) copy-start(p0)
    copy-done.1 = f32[2,3] copy-done(copy-start)
    copy-done.2 = f32[2,3] copy-done(copy-start)
    ROOT tuple = (f32[2,3], f32[2,3]) tuple(copy-done.1, copy-done.2)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("copy-start instruction requires one consumer, found 2"));
}
827 
// A copy-done fed by an ordinary tuple (not a copy-start) must be rejected.
TEST_F(HloVerifierTest, CopyDoneNoCopyStart) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY CopyStartAndCopyDone {
    p0 = f32[2,3] parameter(0)
    p1 = u32[] parameter(1)
    tuple = (f32[2,3], f32[2,3], u32[]) tuple(p0, p0, p1)
    ROOT copy-done = f32[2,3] copy-done(tuple)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("The operand of a copy-done instruction needs to be "
                        "copy-start, found tuple"));
}
848 
// A well-formed custom-call-start/custom-call-done pair must pass the
// layout-sensitive verifier.
TEST_F(HloVerifierTestLayoutSensitive, AsyncStartAndAsyncDone) {
  const char* const hlo_string = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), custom_call_target="foo"
    ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-start), custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // Report the verifier's message on failure instead of a bare boolean.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
865 
// A start -> update -> update -> done chain with consistent shapes must pass
// the layout-sensitive verifier.
TEST_F(HloVerifierTestLayoutSensitive, AsyncStartAndAsyncUpdateAndAsyncDone) {
  const char* const hlo_string = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncUpdateAndAsyncDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), custom_call_target="foo"
    async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), custom_call_target="foo"
    async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), custom_call_target="foo"
    ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // Report the verifier's message on failure instead of a bare boolean.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
884 
// Same chain as AsyncStartAndAsyncUpdateAndAsyncDone, with every op placed on
// the same async_execution_thread; it must still verify.
TEST_F(HloVerifierTestLayoutSensitive,
       AsyncStartAndAsyncUpdateAndAsyncDoneWithThreadName) {
  const char* const hlo_string = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncUpdateAndAsyncDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), async_execution_thread="parallel_thread", custom_call_target="foo"
    async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), async_execution_thread="parallel_thread", custom_call_target="foo"
    async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), async_execution_thread="parallel_thread", custom_call_target="foo"
    ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), async_execution_thread="parallel_thread", custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // Report the verifier's message on failure instead of a bare boolean.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
904 
// async-start's output slot (index {1}) is f32[3,2] but async-done returns
// f32[2,3]; the verifier must report the index-{1} mismatch.
TEST_F(HloVerifierTest, AsyncStartAndAsyncDoneWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[3,2], u32[]) custom-call-start(p0), custom_call_target="foo"
    ROOT async-done = f32[2,3] custom-call-done(async-start), custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-done expects the async shape at index {1} to "
                        "match the async computation root shape"));
}
924 
// async-start runs on "parallel_thread" but its async-done claims
// "main_thread"; the verifier must report the thread-name mismatch.
TEST_F(HloVerifierTest, AsyncStartAndAsyncDoneWrongThreadName) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), async_execution_thread="parallel_thread", custom_call_target="foo"
    ROOT async-done = f32[2,3] custom-call-done(async-start), async_execution_thread="main_thread", custom_call_target="bar"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("thread name (main_thread vs parallel_thread)."));
}
943 
// The start wraps custom_call_target "foo" and the done wraps "bar", so their
// wrapped computations differ and verification must fail.
TEST_F(HloVerifierTest, AsyncStartAndAsyncDoneWrongAttr) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo"
    ROOT async-done = f32[2,3] custom-call-done(async-start), custom_call_target="bar"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-done expects its wrapped async computation to "
                        "be identical to its operand's"));
}
963 
// One async-start feeding two async-done ops must be rejected; an async-start
// needs exactly one consumer.
TEST_F(HloVerifierTest, AsyncStartMultipleAsyncDone) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo"
    async-done.1 = f32[2,3] custom-call-done(async-start), custom_call_target="foo"
    async-done.2 = f32[2,3] custom-call-done(async-start), custom_call_target="foo"
    ROOT tuple = (f32[2,3], f32[2,3]) tuple(async-done.1, async-done.2)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("async-start instruction requires one consumer, found 2"));
}
985 
// An async-start that is the ROOT (no consumer at all) must be rejected.
TEST_F(HloVerifierTest, AsyncStartNoAsyncDone) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    ROOT async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("async-start instruction requires one consumer, found 0"));
}
1004 
// A chain ending in an async-update (no async-done) must be rejected because
// the async-update has no consumer.
TEST_F(HloVerifierTest, AsyncStartAndAsyncUpdateNoAsyncDone) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[2,3], u32[]) custom-call-start(p0), custom_call_target="foo"
    ROOT async-update = ((f32[2,3]), f32[2,3], u32[]) custom-call-update(async-start), custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("async-update instruction requires one consumer, found 0"));
}
1024 
// An async-done fed by an ordinary tuple (neither async-start nor
// async-update) must be rejected.
TEST_F(HloVerifierTest, AsyncDoneNoAsyncStart) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    p1 = u32[] parameter(1)
    tuple = ((f32[2,3]), f32[2,3], u32[]) tuple(p0, p0, p1)
    ROOT async-done = f32[2,3] custom-call-done(tuple), custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("The operand of a async-done instruction needs to be "
                        "async-start or async-update, found tuple"));
}
1045 
// An update -> done chain whose async-update is fed by an ordinary tuple
// instead of an async-start must be rejected at the async-update.
//
// The async-done consumes the async-update, as the test name describes.
// Previously the done consumed `tuple` directly, which left the async-update
// with no user. The expected async-update error then also depended on which
// of several malformed instructions or checks the verifier reported first.
TEST_F(HloVerifierTest, AsyncUpdateAndAsyncDoneNoAsyncStart) {
  const char* const hlo_string = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    p1 = u32[] parameter(1)
    tuple = ((f32[2,3]), f32[2,3], u32[]) tuple(p0, p0, p1)
    async-update = ((f32[2,3]), f32[2,3], u32[]) custom-call-update(tuple), custom_call_target="foo"
    ROOT async-done = f32[2,3] custom-call-done(async-update), custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("The operand of a async-update instruction needs to be "
                        "async-start or async-update, found tuple"));
}
1067 
// The async shape's operand slot (index {0}) is (f32[3,2]) while the wrapped
// computation's parameter is f32[2,3]; the verifier must report it.
TEST_F(HloVerifierTest, AsyncOpComputationParamWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  async_computation {
    p = f32[2,3] parameter(0)
    ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
  }

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[3,2]), f32[3,2], u32[]) async-start(p0), calls=async_computation
    ROOT async-done = f32[3,2] async-done(async-start), calls=async_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-start expects the async shape at index {0} to "
                        "match async computation parameter shape"));
}
1092 
// The async shape's output slot (index {1}) is f32[2,3] while the wrapped
// computation's root is f32[3,2]; async-start verification must fail.
TEST_F(HloVerifierTest, AsyncOpComputationRootWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  async_computation {
    p = f32[2,3] parameter(0)
    ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
  }

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[2,3], u32[]) async-start(p0), calls=async_computation
    ROOT async-done = f32[3,2] async-done(async-start), calls=async_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-start expects the async shape at index {1} to "
                        "match the async computation root shape"));
}
1117 
// An async-start whose shape is a one-element tuple must be rejected: the
// async shape needs at least two elements.
TEST_F(HloVerifierTest, AsyncOpTupleWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  async_computation {
    p = f32[2,3] parameter(0)
    ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
  }

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3])) async-start(p0), calls=async_computation
    ROOT async-done = f32[3,2] async-done(async-start), calls=async_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-start expects the async shape to be a tuple of "
                        "at least two elements"));
}
1142 
// async-start's actual operand is f32[3,2] while its async shape declares
// (f32[2,3]) at index {0}; the verifier must report the operand mismatch.
TEST_F(HloVerifierTest, AsyncStartOperandWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  async_computation {
    p = f32[2,3] parameter(0)
    ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
  }

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[3,2] parameter(0)
    async-start = ((f32[2,3]), f32[3,2], u32[]) async-start(p0), calls=async_computation
    ROOT async-done = f32[3,2] async-done(async-start), calls=async_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-start expects the shape of operand 0 to match "
                        "the async shape at index {0}"));
}
1167 
// async-done returns f32[2,3] but the async shape's index {1} is f32[3,2];
// the verifier must report the output mismatch.
TEST_F(HloVerifierTest, AsyncDoneOutputWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  async_computation {
    p = f32[2,3] parameter(0)
    ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
  }

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[3,2], u32[]) async-start(p0), calls=async_computation
    ROOT async-done = f32[2,3] async-done(async-start), calls=async_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-done expects the shape of output to match the "
                        "async shape at index {1}"));
}
1192 
// async-update changes the async shape's index {0} from (f32[2,3]) to
// (f32[3,2]); an async-update must keep its operand's shape.
TEST_F(HloVerifierTest, AsyncUpdateWrongType) {
  constexpr char kHlo[] = R"(
  HloModule Module

  async_computation {
    p = f32[2,3] parameter(0)
    ROOT custom-call = f32[3,2] custom-call(p), custom_call_target="foo"
  }

  ENTRY AsyncStartAndAsyncDone {
    p0 = f32[2,3] parameter(0)
    async-start = ((f32[2,3]), f32[3,2], u32[]) async-start(p0), calls=async_computation
    async-update = ((f32[3,2]), f32[3,2], u32[]) async-update(async-start), calls=async_computation
    ROOT async-done = f32[3,2] async-done(async-update), calls=async_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr(
          "async-update expects the shape of operand and output to match"));
}
1219 
// Every op in the chain carries async_group_id=0 except the async-done (1);
// the verifier must report the group-id mismatch.
TEST_F(HloVerifierTestLayoutSensitive, AsyncDoneWrongGroupId) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncUpdateAndAsyncDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), async_group_id=0, custom_call_target="foo"
    async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), async_group_id=0, custom_call_target="foo"
    async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), async_group_id=0, custom_call_target="foo"
    ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), async_group_id=1, custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-done expects its operand to have the same group "
                        "id (1 vs 0)."));
}
1241 
// async-update.1 omits async_group_id while its async-start operand has 0;
// the verifier must report the mismatch as "none vs 0".
TEST_F(HloVerifierTestLayoutSensitive, AsyncUpdateWrongGroupId) {
  constexpr char kHlo[] = R"(
  HloModule Module

  ENTRY AsyncStartAndAsyncUpdateAndAsyncDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    async-start = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-start(p0), async_group_id=0, custom_call_target="foo"
    async-update.1 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-start), custom_call_target="foo"
    async-update.2 = ((f32[2,3]{1,0:S(1)}), f32[2,3]{1,0:S(2)}, u32[]) custom-call-update(async-update.1), async_group_id=0, custom_call_target="foo"
    ROOT async-done = f32[2,3]{1,0:S(2)} custom-call-done(async-update.2), async_group_id=0, custom_call_target="foo"
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("async-update expects its operand to have the same "
                        "group id (none vs 0)."));
}
1263 
// An iota producing an (empty) tuple instead of an array must be rejected.
TEST_F(HloVerifierTest, IotaNonArrayResult) {
  constexpr char kHlo[] = R"(
  HloModule IotaTupleResult

  ENTRY  kernelEntry {
    ROOT iota = () iota(), iota_dimension=24
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("does not support non-array result"));
}
1281 
// An iota with iota_dimension=-1 must be rejected with a message mentioning
// the negative dimension.
TEST_F(HloVerifierTest, IotaNegativeDimension) {
  // The module was previously named "IotaTupleResult", copied from
  // IotaNonArrayResult; it is renamed to match what this test exercises.
  const char* const hlo_string = R"(
  HloModule IotaNegativeDimension

  ENTRY  kernelEntry {
    ROOT iota = s32[128,1001]{1,0} iota(), iota_dimension=-1
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(), HasSubstr("negative"));
}
1298 
// An iota with a PRED element type must be rejected.
TEST_F(HloVerifierTest, IotaPredResultNotAllowed) {
  constexpr char kHlo[] = R"(
  HloModule IotaPredResult

  ENTRY  kernelEntry {
    ROOT iota = pred[128] iota(), iota_dimension=0
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(), HasSubstr("got PRED"));
}
1315 
// HLO in which `map` feeds an f64 operand into a computation whose parameter
// is f32. Shared by the strict and the mixed-precision-tolerant map tests.
constexpr char kMapOperandComputationMismatchHlo[] = R"(
  HloModule MapOperandComputationMismatch

  Computation {
    param0 = f32[] parameter(0)
    constant = f32[] constant(1)
    ROOT add = f32[] add(param0, constant)
  }

  ENTRY kernelEntry {
  param = f64[] parameter(0)
  ROOT map = f32[] map(param), dimensions={}, to_apply=Computation
})";
1329 
// Without mixed precision, the f64 operand vs f32 to_apply parameter mismatch
// must be reported.
TEST_F(HloVerifierTest, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      ParseAndReturnUnverifiedModule(kMapOperandComputationMismatchHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Shape mismatch between to_apply computation "
                        "parameter and operand"));
}
1340 
// With mixed precision allowed, the same f64-into-f32 map must verify.
TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(
                                           kMapOperandComputationMismatchHlo));
  // Report the verifier's message on failure instead of a bare boolean.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
1347 
// HLO in which an f16 reduce uses an f32 reducer computation. Shared by the
// strict and the mixed-precision-tolerant reduce tests.
constexpr char kReduceOperandComputationMismatchHlo[] = R"(
  HloModule ReduceOperandComputationMismatch
  computation {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }

  ENTRY kernelEntry {
    arg0 = f16[64,64,224,224]{3,2,1,0} parameter(0)
    constant = f16[] constant(0)
    reduce = f16[64]{0} reduce(arg0, constant), dimensions={0,2,3}, to_apply=computation
  })";
1361 
// Without mixed precision, the verifier expects the reduce result to take the
// reducer's f32 type and rejects the declared f16[64].
TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      ParseAndReturnUnverifiedModule(kReduceOperandComputationMismatchHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to f32[64]"));
}
1371 
// With mixed precision allowed, the f16 reduce with an f32 reducer must
// verify.
TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      ParseAndReturnVerifiedModule(kReduceOperandComputationMismatchHlo));
  // Report the verifier's message on failure instead of a bare boolean.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
1379 
ReplicaGroupsStr(std::vector<std::vector<int64_t>> replica_groups)1380 std::string ReplicaGroupsStr(std::vector<std::vector<int64_t>> replica_groups) {
1381   std::vector<std::string> replica_group_strs;
1382   replica_group_strs.reserve(replica_groups.size());
1383   for (const auto& g : replica_groups) {
1384     replica_group_strs.push_back(
1385         absl::StrFormat("{%s}", absl::StrJoin(g, ",")));
1386   }
1387   return absl::StrFormat("{%s}", absl::StrJoin(replica_group_strs, ", "));
1388 }
1389 
// Returns the total number of replica ids listed across all `replica_groups`
// (the sum of the group sizes). Used as the default replica count when a test
// does not supply one.
int64_t ReplicaCount(const std::vector<std::vector<int64_t>>& replica_groups) {
  int64_t replica_count = 0;
  // Bind each group by const reference; `for (auto group : ...)` copied every
  // inner vector just to read its size.
  for (const auto& group : replica_groups) {
    replica_count += static_cast<int64_t>(group.size());
  }
  return replica_count;
}
1397 
MakeCollectiveCommOpComputation(std::vector<std::vector<int64_t>> replica_groups,std::optional<int64_t> replica_count,std::optional<int64_t> num_partitions,absl::string_view other_attributes,absl::string_view template_str)1398 StatusOr<std::unique_ptr<HloModule>> MakeCollectiveCommOpComputation(
1399     std::vector<std::vector<int64_t>> replica_groups,
1400     std::optional<int64_t> replica_count, std::optional<int64_t> num_partitions,
1401     absl::string_view other_attributes, absl::string_view template_str) {
1402   HloModuleConfig config;
1403   config.set_replica_count(
1404       replica_count.value_or(ReplicaCount(replica_groups)));
1405   config.set_num_partitions(num_partitions.value_or(1));
1406   return ParseAndReturnUnverifiedModule(
1407       absl::StrReplaceAll(
1408           template_str,
1409           {{"REPLICA_GROUPS", ReplicaGroupsStr(replica_groups)},
1410            {"OTHER_ATTRIBUTES", other_attributes.empty()
1411                                     ? ""
1412                                     : absl::StrCat(",", other_attributes)}}),
1413       config);
1414 }
1415 
MakeAllReduceComputation(std::vector<std::vector<int64_t>> replica_groups,std::optional<int64_t> replica_count=std::nullopt,std::optional<int64_t> num_partitions=std::nullopt,absl::string_view other_attributes="")1416 StatusOr<std::unique_ptr<HloModule>> MakeAllReduceComputation(
1417     std::vector<std::vector<int64_t>> replica_groups,
1418     std::optional<int64_t> replica_count = std::nullopt,
1419     std::optional<int64_t> num_partitions = std::nullopt,
1420     absl::string_view other_attributes = "") {
1421   const char* kTemplate = R"(
1422   HloModule test
1423   add {
1424     x = f32[] parameter(0)
1425     y = f32[] parameter(1)
1426     ROOT add = f32[] add(x, y)
1427   }
1428   ENTRY entry {
1429     p = f32[128]{0} parameter(0)
1430     crs = f32[128]{0} all-reduce(p), to_apply=add, replica_groups=REPLICA_GROUPS
1431                                      OTHER_ATTRIBUTES
1432   })";
1433   return MakeCollectiveCommOpComputation(replica_groups, replica_count,
1434                                          num_partitions, other_attributes,
1435                                          kTemplate);
1436 }
1437 
// An all-reduce with an empty replica_groups list must verify.
TEST_F(HloVerifierTest, AllReduce_NoReplicaGroupsOK) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          MakeAllReduceComputation(/*replica_groups=*/{}));
  const auto run_status = verifier().Run(module.get()).status();
  TF_ASSERT_OK(run_status);
}
1442 
// Replica groups of unequal sizes that together name replicas 0..3 exactly
// once must verify.
TEST_F(HloVerifierTest, AllReduce_DifferentGroupSizesOk) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, MakeAllReduceComputation(/*replica_groups=*/{{0}, {1, 3}, {2}}));
  const auto run_status = verifier().Run(module.get()).status();
  TF_ASSERT_OK(run_status);
}
1448 
// A replica_groups list containing an empty group must be rejected.
TEST_F(HloVerifierTest, AllReduce_EmptyReplicaGroup) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          MakeAllReduceComputation(/*replica_groups=*/{{0}, {}}));
  const auto run_status = verifier().Run(module.get()).status();
  EXPECT_THAT(run_status.error_message(), HasSubstr("empty replica group"));
}
1454 
// Replica 0 appears in two groups, which must be rejected.
TEST_F(HloVerifierTest, AllReduce_RepeatedReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      MakeAllReduceComputation(/*replica_groups=*/{{0, 1}, {2, 3}, {4, 0}}));
  const auto run_status = verifier().Run(module.get()).status();
  EXPECT_THAT(run_status.error_message(), HasSubstr("Replica 0 is repeated"));
}
1461 
// The groups list six replicas but skip id 4; the gap must be reported.
TEST_F(HloVerifierTest, AllReduce_MissingReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module,
      MakeAllReduceComputation(/*replica_groups=*/{{0, 1}, {2, 3}, {5, 6}}));
  const auto run_status = verifier().Run(module.get()).status();
  EXPECT_THAT(run_status.error_message(), HasSubstr("Replica 4 is not named"));
}
1468 
// The module is configured for 8 replicas but the groups name only 2.
// (The "Enoug" typo in the test name is kept so test filters keep working.)
TEST_F(HloVerifierTest, AllReduce_NotEnougReplicasInGroupConfig) {
  TF_ASSERT_OK_AND_ASSIGN(
      auto module, MakeAllReduceComputation(/*replica_groups=*/{{0, 1}},
                                            /*replica_count=*/8));
  const auto run_status = verifier().Run(module.get()).status();
  EXPECT_THAT(run_status.error_message(),
              HasSubstr("In kCrossReplica mode, replica groups should contain "
                        "8 replicas, but found 2"));
}
1475 
// With only 2 replicas configured, groups that name 4 replicas are rejected
// in cross-replica mode.
TEST_F(HloVerifierTest, AllReduce_TooManyReplicasInGroupConfig) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          MakeAllReduceComputation({{0, 1}, {2, 3}}, 2));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message,
              HasSubstr("In kCrossReplica mode, replica groups should contain "
                        "2 replicas, but found 4"));
}
1483 
// Setting channel_id (without use_global_device_ids) selects the
// cross-replica-and-partition mode. Groups covering 4 replicas are invalid
// when only 2 replicas (and 1 partition) are configured.
TEST_F(HloVerifierTest, AllReduce_CrossReplicaAndPartition_Invalid) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 2, 1, "channel_id=1"));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(
      error_message,
      HasSubstr(
          "In kCrossReplicaAndPartition mode, replica groups should contain "
          "2 replicas, but found 4"));
}
1494 
// Same groups as the _Invalid case, but with 4 replicas configured: every
// replica is covered, so verification succeeds.
TEST_F(HloVerifierTest, AllReduce_CrossReplicaAndPartition_Valid) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 4, 1, "channel_id=1"));
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
1501 
// use_global_device_ids=true selects flattened-ID mode. With 1 replica x 2
// partitions there are only 2 flattened IDs, so groups naming 4 are invalid.
TEST_F(HloVerifierTest, AllReduce_FlattenedID_Invalid) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 1, 2,
                               "channel_id=1, use_global_device_ids=true"));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message,
              HasSubstr("In kFlattenedID mode, replica groups should contain "
                        "2 flattened IDs, but found 4"));
}
1511 
// Flattened-ID mode with 2 replicas x 2 partitions yields 4 IDs, exactly those
// named by the groups, so verification succeeds.
TEST_F(HloVerifierTest, AllReduce_FlattenedID_Valid) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      MakeAllReduceComputation({{0, 1}, {2, 3}}, 2, 2,
                               "channel_id=1, use_global_device_ids=true"));
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
1519 
// A well-formed all-reduce-start / all-reduce-done pair, where the start op
// has the same array shape as its operand, passes verification.
TEST_F(HloVerifierTest, AllReduceStartAndDone) {
  const char* const kModuleStr = R"(
  HloModule test
  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    start = f32[2,3] all-reduce-start(p0), to_apply=add
    ROOT done = f32[2,3] all-reduce-done(start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  // TF_ASSERT_OK (instead of ASSERT_TRUE(status.ok())) prints the verifier's
  // error message if this positive case ever regresses.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
1540 
// all-reduce-start declared with a tuple shape instead of its operand's array
// shape (f32[2,3]) must be rejected.
TEST_F(HloVerifierTest, AllReduceStartAndDoneWrongType) {
  const char* const kModuleStr = R"(
  HloModule test
  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    start = (f32[2,3], f32[2,3]) all-reduce-start(p0), to_apply=add
    ROOT done = f32[2,3] all-reduce-done(start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  auto status = verifier().Run(module.get()).status();
  // Fail fast with a clear message if verification unexpectedly succeeds,
  // matching the other negative start/done tests in this file.
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "f32[2,3]"));
}
1563 
// An all-reduce-start consumed by two all-reduce-done ops must be rejected.
// The start op uses the valid array shape (f32[2,3], as in
// AllReduceStartAndDone), so the module's only defect is the extra consumer.
TEST_F(HloVerifierTest, AllReduceStartAndMultipleDone) {
  const char* const kModuleStr = R"(
  HloModule test
  add {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    start = f32[2,3] all-reduce-start(p0), to_apply=add
    done1 = f32[2,3] all-reduce-done(start)
    ROOT done2 = f32[2,3] all-reduce-done(start)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(
      status.error_message(),
      HasSubstr("all-reduce-start instruction requires one consumer, found 2"));
}
1588 
// all-reduce-done whose operand is a plain tuple (not an all-reduce-start)
// must be rejected. The tuple's declared shape matches its four operands, so
// the module's only defect is the missing all-reduce-start.
TEST_F(HloVerifierTest, AllReduceDoneWithoutStart) {
  const char* const kModuleStr = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[2,3] parameter(0)
    p1 = u32[] parameter(1)
    tuple = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p0, p1, p1)
    ROOT done = f32[2,3] all-reduce-done(tuple)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("The operand of a all-reduce-done instruction "
                        "needs to be all-reduce-start, found tuple"));
}
1608 
MakeAllToAllComputation(std::vector<std::vector<int64_t>> replica_groups,std::optional<int64_t> replica_count=std::nullopt,std::optional<int64_t> num_partitions=std::nullopt,absl::string_view other_attributes="")1609 StatusOr<std::unique_ptr<HloModule>> MakeAllToAllComputation(
1610     std::vector<std::vector<int64_t>> replica_groups,
1611     std::optional<int64_t> replica_count = std::nullopt,
1612     std::optional<int64_t> num_partitions = std::nullopt,
1613     absl::string_view other_attributes = "") {
1614   const char* kTemplate = R"(
1615   HloModule test
1616   ENTRY entry {
1617     p0 = f32[128]{0} parameter(0)
1618     p1 = f32[128]{0} parameter(1)
1619     a2a = (f32[128], f32[128]) all-to-all(p0, p1), replica_groups=REPLICA_GROUPS
1620                                                    OTHER_ATTRIBUTES
1621   })";
1622   return MakeCollectiveCommOpComputation(replica_groups, replica_count,
1623                                          num_partitions, other_attributes,
1624                                          kTemplate);
1625 }
1626 
// An all-to-all built with an empty replica_groups list passes verification.
TEST_F(HloVerifierTest, AllToAll_NoReplicaGroupsOK) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          MakeAllToAllComputation({}));
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
1631 
// An all-to-all must not contain an empty replica group ({0, 1}, {}).
TEST_F(HloVerifierTest, AllToAll_EmptyReplicaGroup) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          MakeAllToAllComputation({{0, 1}, {}}));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message, HasSubstr("cannot have an empty replica group"));
}
1637 
// Replica 0 appears in two all-to-all groups; the verifier must flag it.
TEST_F(HloVerifierTest, AllToAll_RepeatedReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      MakeAllToAllComputation({{0, 1}, {2, 3}, {4, 0}}));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message, HasSubstr("Replica 0 is repeated"));
}
1644 
// The all-to-all groups skip replica 4; the verifier must report it.
TEST_F(HloVerifierTest, AllToAll_MissingReplicaId) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      MakeAllToAllComputation({{0, 1}, {2, 3}, {5, 6}}));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message, HasSubstr("Replica 4 is not named"));
}
1651 
// Unlike all-reduce, all-to-all requires every replica group to have the same
// size; {0, 1}, {2}, {3, 4} must be rejected.
TEST_F(HloVerifierTest, AllToAll_UniformSizeOfReplicasInGroup) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      MakeAllToAllComputation({{0, 1}, {2}, {3, 4}}));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message,
              HasSubstr("Replica groups expected to be of uniform size"));
}
1658 
// With channel_id set, all-to-all groups name partitions. Only 2 partitions are
// configured, so groups covering 4 are invalid.
TEST_F(HloVerifierTest, AllToAll_CrossPartition_Invalid) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      MakeAllToAllComputation({{0, 1}, {2, 3}}, 1, 2, "channel_id=1"));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message,
              HasSubstr("In kCrossPartition mode, replica groups should "
                        "contain 2 partitions, but found 4"));
}
1667 
// Same groups as the _Invalid case, but with 4 partitions configured, so
// verification succeeds.
TEST_F(HloVerifierTest, AllToAll_CrossPartition_Valid) {
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> hlo_module,
      MakeAllToAllComputation({{0, 1}, {2, 3}}, 1, 4, "channel_id=1"));
  const auto verify_status = verifier().Run(hlo_module.get()).status();
  TF_ASSERT_OK(verify_status);
}
1674 
// The two all-to-all operands differ only in layout ({0,1} vs {1,0}); the
// verifier reports them as having different shapes.
TEST_F(HloVerifierTest, AllToAll_LayoutConstrained) {
  constexpr char kHloText[] = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128,4]{0,1} parameter(0)
    p1 = f32[128,4]{1,0} parameter(1)
    ROOT a2a = (f32[128,4]{0,1}, f32[128,4]{1,0}) all-to-all(p0, p1),
      replica_groups={{0,1}}
  }
  )";
  HloModuleConfig module_config;
  module_config.set_replica_count(2);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnUnverifiedModule(kHloText, module_config));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message,
              HasSubstr("HLO all-to-all has operands with different shapes"));
}
1692 
// Source 0 appears in two pairs ({0,1} and {0,2}); a plain collective-permute
// must reject the duplicate source.
TEST_F(HloVerifierTest, CollectivePermuteSameSourceTwice) {
  constexpr char kHloText[] = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,1}, {0,2}, {1,0}}
  }
  )";
  HloModuleConfig module_config;
  module_config.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnUnverifiedModule(kHloText, module_config));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message, HasSubstr("Source 0 appears more than once"));
}
1709 
// Target 2 appears in two pairs ({0,2} and {1,2}); a plain collective-permute
// must reject the duplicate target.
TEST_F(HloVerifierTest, CollectivePermuteSameTargetTwice) {
  const char* const kModuleStr = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,2}, {1,2}, {2,0}}
  }
  )";
  // Configure three replicas, as CollectivePermuteSameSourceTwice does, so
  // every id in source_target_pairs is in range and the duplicate-target check
  // is what fires. This keeps the test independent of how the verifier treats
  // a default single-replica config.
  HloModuleConfig config;
  config.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kModuleStr, config));
  EXPECT_THAT(verifier().Run(module.get()).status().error_message(),
              HasSubstr("Target 2 appears more than once"));
}
1724 
// In-place (four-operand) collective-permute: source 0 appears three times in
// source_target_pairs ({0,1}, {0,3}, {0,2}), exceeding the limit of 2 named in
// the expected error.
TEST_F(HloVerifierTest, CollectivePermuteSameSourceTooManyTimes) {
  constexpr char kHloText[] = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    ROOT collective-permute = u32[2,8,128]{2,1,0:T(2,128)} collective-permute(u32[2,8,128] broadcast.0, u32[2,8,128] broadcast.1, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{0,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message,
              HasSubstr("Source 0 appears more than 2 times in instruction's "
                        "source-target pairs:"));
}
1747 
// In-place (four-operand) collective-permute: target 3 appears three times in
// source_target_pairs ({2,3} twice and {0,3}), exceeding the limit of 2 named
// in the expected error.
TEST_F(HloVerifierTest, CollectivePermuteSameTargetTooManyTimes) {
  constexpr char kHloText[] = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    ROOT collective-permute = u32[2,8,128]{2,1,0:T(2,128)} collective-permute(u32[2,8,128] broadcast.0, u32[2,8,128] broadcast.1, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,3},{1,0}}, slice_sizes={{1,8,128},{1,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message,
              HasSubstr("Target 3 appears more than 2 times in instruction's "
                        "source-target pairs:"));
}
1770 
// In-place collective-permute whose input operand (broadcast.0, a single
// array) does not match the structure of its output operand (tuple.output, a
// two-element tuple).
TEST_F(HloVerifierTest, CollectivePermuteUnmatchingSourceTarget) {
  constexpr char kHloText[] = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    constant.4 = s32[] constant(2)
    tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
    tuple.6 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.5)
    tuple.9 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.6)
    ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message,
              HasSubstr("Unmatching input buffers and output buffers"));
}
1798 
// In-place collective-permute whose input operand is a two-element tuple
// (tuple.input) while its input-offset operand (tuple.3) is a single
// (s32[],s32[],s32[]) tuple; the structures must match.
TEST_F(HloVerifierTest, CollectivePermuteUnmatchingInputAndInputOffset) {
  constexpr char kHloText[] = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
    broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
    constant.1 = u32[] constant(1000)
    broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
    constant.2 = s32[] constant(0)
    constant.3 = s32[] constant(1)
    tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, u32[2,8,128]{2,1,0:T(2,128)} broadcast.0)
    tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
    tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
    tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
    tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
    constant.4 = s32[] constant(2)
    tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
    tuple.6 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.5)
    tuple.9 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.6)
    ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (s32[],s32[],s32[]) tuple.3, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.9), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message,
              HasSubstr("Unmatching input buffers and input offset."));
}
1827 
// In-place collective-permute whose output operand is a two-element tuple
// (tuple.output) while its output-offset operand (tuple.2) is a single
// (s32[],s32[],s32[]) tuple; the structures must match.
TEST_F(HloVerifierTest, CollectivePermuteUnmatchingOutputAndOutputOffset) {
  constexpr char kHloText[] = R"(
  HloModule test
  ENTRY entry {
    replica_id = u32[] replica-id()
      broadcast.0 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] replica_id), dimensions={}
      constant.1 = u32[] constant(1000)
      broadcast.1 = u32[2,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
      broadcast.2 = u32[4,8,128]{2,1,0:T(2,128)} broadcast(u32[] constant.1), dimensions={}
      constant.2 = s32[] constant(0)
      constant.3 = s32[] constant(1)
      tuple.input = (u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.0, u32[2,8,128]{2,1,0:T(2,128)} broadcast.0)
      tuple.output = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple(u32[2,8,128]{2,1,0:T(2,128)} broadcast.1, u32[4,8,128]{2,1,0:T(2,128)} broadcast.2)
      tuple.2 = (s32[],s32[],s32[]) tuple(constant.2, constant.2, constant.2)
      tuple.3 = (s32[],s32[],s32[]) tuple(constant.3, constant.2, constant.2)
      tuple.4 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.3)
      constant.4 = s32[] constant(2)
      tuple.5 = (s32[],s32[],s32[]) tuple(constant.4, constant.2, constant.2)
      tuple.7 = ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple((s32[],s32[],s32[]) tuple.2, (s32[],s32[],s32[]) tuple.2)
      tuple.8 = (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple(((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.4, ((s32[],s32[],s32[]), (s32[],s32[],s32[])) tuple.7)
      ROOT collective-permute.53 = (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) collective-permute((u32[2,8,128]{2,1,0:T(2,128)}, u32[2,8,128]{2,1,0:T(2,128)}) tuple.input, (u32[2,8,128]{2,1,0:T(2,128)}, u32[4,8,128]{2,1,0:T(2,128)}) tuple.output, (((s32[],s32[],s32[]), (s32[],s32[],s32[])), ((s32[],s32[],s32[]), (s32[],s32[],s32[]))) tuple.8, (s32[],s32[],s32[]) tuple.2), source_target_pairs={{0,1},{1,2},{2,3},{3,0},{0,3},{3,2},{2,1},{1,0}}, slice_sizes={{1,8,128},{1,8,128},{2,8,128},{2,8,128}}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message,
              HasSubstr("Unmatching output buffers and output offset."));
}
1856 
// Cross-replica collective-permute (no channel_id) with 3 replicas: source 5
// is out of range.
TEST_F(HloVerifierTest, CollectivePermuteCrossReplicaSourceOOR) {
  constexpr char kHloText[] = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{5,2}, {1,2}, {2,0}}
  }
  )";
  HloModuleConfig module_config;
  module_config.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnUnverifiedModule(kHloText, module_config));
  const std::string message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(message, HasSubstr("Source 5"));
  EXPECT_THAT(message, HasSubstr("must be < 3"));
}
1875 
// Cross-replica collective-permute (no channel_id) with 3 replicas: target 7
// is out of range.
TEST_F(HloVerifierTest, CollectivePermuteCrossReplicaTargetOOR) {
  constexpr char kHloText[] = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,1}, {1,2}, {2,7}}
  }
  )";
  HloModuleConfig module_config;
  module_config.set_replica_count(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnUnverifiedModule(kHloText, module_config));
  const std::string message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(message, HasSubstr("Target 7"));
  EXPECT_THAT(message, HasSubstr("must be < 3"));
}
1894 
// With channel_id set, pairs name partitions. With 3 partitions configured,
// source 5 is out of range.
TEST_F(HloVerifierTest, CollectivePermuteCrossPartitionSourceOOR) {
  constexpr char kHloText[] = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{5,2}, {1,2}, {2,0}}, channel_id=1
  }
  )";
  HloModuleConfig module_config;
  module_config.set_num_partitions(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnUnverifiedModule(kHloText, module_config));
  const std::string message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(message, HasSubstr("Source 5"));
  EXPECT_THAT(message, HasSubstr("must be < 3"));
}
1913 
// With channel_id set, pairs name partitions. With 3 partitions configured,
// target 7 is out of range.
TEST_F(HloVerifierTest, CollectivePermuteCrossPartitionTargetOOR) {
  constexpr char kHloText[] = R"(
  HloModule test
  ENTRY entry {
    p0 = f32[128] parameter(0)
    ROOT permute = f32[128] collective-permute(p0),
      source_target_pairs={{0,2}, {1,7}, {2,0}}, channel_id=1
  }
  )";
  HloModuleConfig module_config;
  module_config.set_num_partitions(3);
  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module, ParseAndReturnUnverifiedModule(kHloText, module_config));
  const std::string message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(message, HasSubstr("Target 7"));
  EXPECT_THAT(message, HasSubstr("must be < 3"));
}
1932 
// The fusion is declared as f32[10] but its fused computation returns
// f32[10,10]; the verifier must report the shape mismatch.
TEST_F(HloVerifierTest, FusionShapeVerifier) {
  constexpr char kHloText[] = R"(
  HloModule test

  fused_computation {
    ROOT p0 = f32[10,10] parameter(0)
  }

  ENTRY entry {
    p0 = f32[10,10] parameter(0)
    ROOT out = f32[10] fusion(p0), kind=kInput, calls=fused_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message, HasSubstr("Fused computation shape"));
}
1951 
// The fused computation runs on execution_thread "parallel_thread" but is
// called from the entry computation, which has no such annotation; the
// verifier must reject the thread mismatch.
TEST_F(HloVerifierTest, FusionThreadVerifier) {
  constexpr char kHloText[] = R"(
  HloModule test

  fused_computation {
    ROOT p0 = f32[8,12] parameter(0)
  }, execution_thread="parallel_thread"

  ENTRY entry {
    p0 = f32[8,12] parameter(0)
    ROOT out = f32[8,12] fusion(p0), kind=kInput, calls=fused_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message,
              HasSubstr("expects parent computation thread name same as called "
                        "computation's thread name"));
}
1971 
// The all-reduce reducer `add` is tagged execution_thread="parallel_thread",
// but it is called from inside fused_computation, which has no such tag. The
// verifier must reject the nested thread mismatch.
TEST_F(HloVerifierTest, FusionNestedComputationThreadVerifier) {
  constexpr char kHloText[] = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }, execution_thread="parallel_thread"

  fused_computation {
    p0 = f32[8,12] parameter(0)
    p1 = f32[8,12] parameter(1)
    crs0 = f32[8,12] all-reduce(p1), replica_groups={}, to_apply=add
    ROOT result = add(p0, crs0)
  }

  ENTRY entry {
    p0 = f32[8,12] parameter(0)
    p1 = f32[8,12] parameter(1)
    ROOT out = f32[8,12] fusion(p0, p1), kind=kInput, calls=fused_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(
      error_message,
      HasSubstr("Nested computations expects same computation's thread name"));
}
2001 
// One all-reduce sets constrain_layout=true and the other does not; a module
// may not mix the two.
TEST_F(HloVerifierTest, AllReduceVerifier) {
  constexpr char kHloText[] = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY entry {
    input = f32[8,12]{0,1} parameter(0)
    crs0 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add
    crs1 = f32[8,12]{0,1} all-reduce(input), replica_groups={}, to_apply=add,
      constrain_layout=true
    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(crs0, crs1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(
      error_message,
      HasSubstr("mix of layout constrained and unconstrained AllReduce"));
}
2026 
// channel_id=1 is shared by a send/send-done pair and an all-reduce; one
// channel must not be reused by different kinds of channel instructions.
TEST_F(HloVerifierTest, ChannelVerifier) {
  constexpr char kHloText[] = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY entry {
    %input = f32[8,12] parameter(0)
    %token0 = token[] after-all()
    %send = (f32[8,12], u32[], token[]) send(%input, %token0), channel_id=1
    %send-done = token[] send-done(%send), channel_id=1
    %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
      channel_id=1
    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%input, %crs)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message,
              HasSubstr("used for different types of channel instructions"));
}
2052 
// channel_id=1 is shared by a collective-permute and an all-reduce; different
// collective kinds must not share a channel.
TEST_F(HloVerifierTest, CollectiveChannelVerifier) {
  constexpr char kHloText[] = R"(
  HloModule test

  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY entry {
    %input = f32[8,12] parameter(0)
    %permute = f32[8,12] collective-permute(%input),
      source_target_pairs={{0,1},{1,0}}, channel_id=1
    %crs = f32[8,12] all-reduce(%input), replica_groups={}, to_apply=add,
      channel_id=1
    ROOT result = (f32[8,12]{0,1}, f32[8,12]{0,1}) tuple(%permute, %crs)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnUnverifiedModule(kHloText));
  const std::string error_message =
      verifier().Run(hlo_module.get()).status().error_message();
  EXPECT_THAT(error_message,
              HasSubstr("used for different types of channel instructions"));
}
2077 
// A well-formed collective-permute-start / collective-permute-done pair passes
// layout-sensitive verification. The start op produces the 4-tuple
// (operand, result, u32[], u32[]), and all arrays use the S(1) layout.
TEST_F(HloVerifierTestLayoutSensitive, CollectivePermuteStartAndDone) {
  const char* const kModuleStr = R"(
  HloModule Module

  ENTRY CollectivePermuteStartAndDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
    ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(kModuleStr));

  // TF_ASSERT_OK (instead of ASSERT_TRUE(status.ok())) prints the verifier's
  // error message if this positive case ever regresses.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
2094 
// collective-permute-start must produce the 4-tuple
// (f32[2,3], f32[2,3], u32[], u32[]); declaring it as a plain f32[2,3] array
// has to be reported as a shape mismatch.
TEST_F(HloVerifierTest, CollectivePermuteStartAndDoneWrongType) {
  constexpr char kHlo[] = R"hlo(
  HloModule Module

  ENTRY CollectivePermuteStartAndDoneWrongType {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    collective-permute-start.1 = f32[2,3]{1,0:S(1)} collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
    ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
  }
  )hlo";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Expected instruction to have shape equal to "
                        "(f32[2,3], f32[2,3], u32[], u32[])"));
}
2114 
// Two collective-permute-done ops consuming the same start op must be
// rejected: the start op is allowed exactly one consumer.
TEST_F(HloVerifierTest, CollectivePermuteStartAndMultipleDone) {
  constexpr char kHlo[] = R"hlo(
  HloModule Module

  ENTRY CollectivePermuteStartAndMultipleDone {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    collective-permute-start.1 = (f32[2,3]{1,0:S(1)}, f32[2,3]{1,0:S(1)}, u32[], u32[]) collective-permute-start(p0), source_target_pairs={{0,1},{1,0}}, channel_id=1
    collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
    ROOT collective-permute-done.2 = f32[2,3]{1,0:S(1)} collective-permute-done(collective-permute-start.1)
  }
  )hlo";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("collective-permute-start instruction requires one "
                        "consumer, found 2"));
}
2136 
// Feeding collective-permute-done from a hand-built tuple (even one with the
// right element shapes) instead of a collective-permute-start is an error.
TEST_F(HloVerifierTest, CollectivePermuteDoneNoCollectivePermuteStart) {
  constexpr char kHlo[] = R"hlo(
  HloModule Module

  ENTRY CollectivePermuteDoneNoCollectivePermuteStart {
    p0 = f32[2,3]{1,0:S(1)} parameter(0)
    p1 = f32[2,3]{1,0:S(1)} parameter(1)
    p2 = u32[] parameter(2)
    p3 = u32[] parameter(3)
    tuple.1 = (f32[2,3], f32[2,3], u32[], u32[]) tuple(p0, p1, p2, p3)
    ROOT collective-permute-done.1 = f32[2,3]{1,0:S(1)} collective-permute-done(tuple.1)
  }
  )hlo";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("The operand of a collective-permute-done instruction "
                        "needs to be collective-permute-start, found tuple"));
}
2159 
// A compare on f32 operands must use comparison type FLOAT or TOTALORDER;
// type=UNSIGNED has to be rejected.
TEST_F(HloVerifierTest, ComparisonTypeFloat) {
  // The entry computation was previously named after an unrelated RNG test
  // (copy-paste); it is now named after this test.
  const char* const hlo_string = R"(
  HloModule Module

  ENTRY ComparisonTypeFloat {
   p0 = f32[] parameter(0)
   ROOT cmp = pred[] compare(f32[] p0, f32[] p0), direction=LT, type=UNSIGNED
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Expected comparison type FLOAT or TOTALORDER"));
}
2177 
// A compare on s32 operands must use comparison type SIGNED; type=UNSIGNED
// has to be rejected.
TEST_F(HloVerifierTest, ComparisonTypeSigned) {
  // The entry computation was previously named after an unrelated RNG test
  // (copy-paste); it is now named after this test.
  const char* const hlo_string = R"(
  HloModule Module

  ENTRY ComparisonTypeSigned {
   p0 = s32[] parameter(0)
   ROOT cmp = pred[] compare(s32[] p0, s32[] p0), direction=LT, type=UNSIGNED
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Expected comparison type SIGNED"));
}
2195 
// A compare on u32 operands must use comparison type UNSIGNED; type=SIGNED
// has to be rejected.
TEST_F(HloVerifierTest, ComparisonTypeUnsigned) {
  // The entry computation was previously named after an unrelated RNG test
  // (copy-paste); it is now named after this test.
  const char* const hlo_string = R"(
  HloModule Module

  ENTRY ComparisonTypeUnsigned {
   p0 = u32[] parameter(0)
   ROOT cmp = pred[] compare(u32[] p0, u32[] p0), direction=LT, type=SIGNED
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Expected comparison type UNSIGNED"));
}
2213 
// A compare on pred operands is expected to use comparison type UNSIGNED;
// type=SIGNED has to be rejected.
TEST_F(HloVerifierTest, ComparisonTypePred) {
  // The entry computation was previously named after an unrelated RNG test
  // (copy-paste); it is now named after this test.
  const char* const hlo_string = R"(
  HloModule Module

  ENTRY ComparisonTypePred {
   p0 = pred[] parameter(0)
   ROOT cmp = pred[] compare(pred[] p0, pred[] p0), direction=LT, type=SIGNED
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  auto status = verifier().Run(module.get()).status();
  ASSERT_FALSE(status.ok());
  EXPECT_THAT(status.error_message(),
              HasSubstr("Expected comparison type UNSIGNED"));
}
2231 
// An all-reduce with use_global_device_ids=true (flattened-id mode) must not
// leave replica_groups empty.
TEST_F(HloVerifierTest, UseGlobalDeviceIdsEmptyReplicaGroup) {
  constexpr char kHlo[] = R"hlo(
  HloModule Module
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, channel_id=1,
                         use_global_device_ids=true, to_apply=add
  })hlo";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(
      verify_status.error_message(),
      HasSubstr("Replica groups must be specified in flattened-id mode"));
}
2255 
// use_global_device_ids=true on an all-reduce that has no channel_id is an
// invalid attribute combination.
TEST_F(HloVerifierTest, InvalidChannelIDandUseGlobalDeviceIDs) {
  constexpr char kHlo[] = R"hlo(
  HloModule Module
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    ROOT crs = f32[8]{0} all-reduce(input), replica_groups={},
                         use_global_device_ids=true, to_apply=add
  })hlo";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Invalid combination of has_channel_id and "
                        "use_global_device_ids"));
}
2280 
// A reduce-scatter over a 2-member replica group whose output (f32[8]) is the
// same size as its input (f32[8]) has a shard count of 1 instead of 2 and must
// be rejected.
TEST_F(HloVerifierTest, ReduceScatterInvalidOutputSize0) {
  constexpr char kHlo[] = R"hlo(
  HloModule Module
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    ROOT crs = f32[8]{0} reduce-scatter(input), replica_groups={{0,1}},
                         to_apply=add, dimensions={0}
  })hlo";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("shard_count = 1, subgroup_size = 2"));
}
2303 
// A reduce-scatter whose scatter dimension (1) is out of range for a rank-1
// operand must fail the scatter-dimension bound check.
TEST_F(HloVerifierTest, ReduceScatterInvalidScatterDim) {
  constexpr char kHlo[] = R"hlo(
  HloModule Module
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    ROOT crs = f32[4]{0} reduce-scatter(input), replica_groups={{0,1}},
                         to_apply=add, dimensions={1}
  })hlo";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("ars->scatter_dimension() < "
                        "ars->operand(i)->shape().rank()"));
}
2327 
// reduce-scatter replica groups of differing sizes ({0,1} vs {2,3,4}) must be
// rejected.
TEST_F(HloVerifierTest, ReduceScatterNonUniformGroups) {
  constexpr char kHlo[] = R"hlo(
  HloModule Module
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[8]{0} parameter(0)
    ROOT crs = f32[4]{0} reduce-scatter(input), replica_groups={{0,1}, {2,3,4}},
                         to_apply=add, dimensions={0}
  })hlo";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Replica groups expected to be of uniform size"));
}
2350 
// With VerifyBroadcastDimensionsOrder enabled, a broadcast whose dimensions
// attribute is not in increasing order ({3,2,1}) must be rejected.
TEST_F(HloVerifierTest, VerifyBroadcastDimensionsOrder) {
  constexpr char kHlo[] = R"hlo(
HloModule module

ENTRY computation {
  mul = f32[32,32,32]{2,1,0} parameter(0)
  ROOT broadcast = f32[32,32,32,32]{3,2,1,0} broadcast(mul), dimensions={3,2,1}
}
)hlo";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  HloVerifier order_checker(
      HloVerifierOpts{}.VerifyBroadcastDimensionsOrder());
  const auto verify_status = order_checker.Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Broadcast dimensions should be ordered"));
}
2369 
// With VerifyBroadcastDimensionsOrder enabled, increasing broadcast dimensions
// ({0,3}) are accepted.
TEST_F(HloVerifierTest, VerifyBroadcastDimensionsOrderOK) {
  constexpr char kHlo[] = R"hlo(
HloModule module

ENTRY computation {
  mul = f32[4,5] parameter(0)
  ROOT broadcast = f32[4,3,2,5] broadcast(mul), dimensions={0,3}
}
)hlo";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  HloVerifier order_checker(
      HloVerifierOpts{}.VerifyBroadcastDimensionsOrder());
  TF_ASSERT_OK(order_checker.Run(module.get()).status());
}
2385 
// Under a layout-sensitive verifier with VerifyReshapeIsBitcast, a reshape
// from f32[8,3]{1,0} to f32[4,2,3]{0,1,2} changes the physical element order
// and must be rejected.
TEST_F(HloVerifierTest, ReshapeIsNotBitcast) {
  constexpr char kHlo[] = R"hlo(
HloModule Module

ENTRY main {
  p = f32[8,3]{1,0} parameter(0)
  ROOT r = f32[4,2,3]{0,1,2} reshape(p)
}
)hlo";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  HloVerifier bitcast_checker(
      HloVerifierOpts{}.MakeLayoutSensitive().VerifyReshapeIsBitcast());
  const auto verify_status = bitcast_checker.Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("Reshape should be a physical bitcast"));
}
2406 
// Under a layout-sensitive verifier with VerifyReshapeIsBitcast, reshaping
// f32[8]{0} to f32[4,2]{1,0} is a pure bitcast and must be accepted.
TEST_F(HloVerifierTest, ReshapeIsBitcast) {
  constexpr char kHlo[] = R"hlo(
HloModule Module

ENTRY main {
  p = f32[8]{0} parameter(0)
  ROOT r = f32[4,2]{1,0} reshape(p)
}
)hlo";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  HloVerifier bitcast_checker(
      HloVerifierOpts{}.MakeLayoutSensitive().VerifyReshapeIsBitcast());
  TF_ASSERT_OK(bitcast_checker.Run(module.get()).status());
}
2423 
// With VerifyCustomCallNestedComputationThreadName enabled, a custom-call in
// the (unannotated) entry computation whose to_apply computation is marked
// execution_thread="parallel_thread" must be rejected.
TEST_F(HloVerifierTest, VerifyCustomCallThread) {
  constexpr char kHlo[] = R"hlo(
    HloModule module
    %call_body (prev.2: s32[]) -> pred[] {
      %constant.1 = s32[] constant(5)
      %prev.2 = s32[] parameter(0)
      ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT
    }, execution_thread="parallel_thread"

    ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
      %constant.2 = s32[] constant(0)
      ROOT %custom = s32[] custom-call(s32[] %constant.2), custom_call_target="MyCustomCall", to_apply=%call_body
    }
)hlo";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));
  HloVerifier thread_checker(
      HloVerifierOpts{}.VerifyCustomCallNestedComputationThreadName());
  const auto verify_status = thread_checker.Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("expects parent computation thread name same as called "
                        "computation's thread name"));
}
2450 
// A while loop in the (unannotated) entry computation whose condition is
// marked execution_thread="parallel_thread" must be rejected by the default
// verifier.
TEST_F(HloVerifierTest, CheckWhileThread) {
  constexpr char kHlo[] = R"hlo(
    HloModule While, entry_computation_layout={()->s32[]}

    %body.v3 (prev.1: s32[]) -> s32[] {
      %constant = s32[] constant(1)
      %prev.1 = s32[] parameter(0)
      ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1)
    }

    %condition.v3 (prev.2: s32[]) -> pred[] {
      %constant.1 = s32[] constant(5)
      %prev.2 = s32[] parameter(0)
      ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT
    }, execution_thread="parallel_thread"

    ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
      %constant.2 = s32[] constant(0)
      ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3
    }
  )hlo";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(kHlo));

  const auto verify_status = verifier().Run(module.get()).status();
  ASSERT_FALSE(verify_status.ok());
  EXPECT_THAT(verify_status.error_message(),
              HasSubstr("expects parent computation thread name same as called "
                        "computation's thread name"));
}
2480 
// The while body and condition carry no execution_thread annotation, while the
// condition contains custom-call-start/done ops with
// async_execution_thread="parallel_thread". This module is expected to verify
// cleanly.
TEST_F(HloVerifierTest, CheckWhileContainsAsyncThread) {
  const char* const hlo_string = R"(
    HloModule While, entry_computation_layout={()->s32[]}

    %async_add (prev.1: s32[]) -> s32[] {
      %constant = s32[] constant(1)
      %prev.1 = s32[] parameter(0)
      ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1)
    }, execution_thread="parallel_thread"

    %body.v3 (prev.1: s32[]) -> s32[] {
      %constant = s32[] constant(1)
      %prev.1 = s32[] parameter(0)
      ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1)
    }

    %condition.v3 (prev.2: s32[]) -> pred[] {
      %constant.1 = s32[] constant(5)
      %prev.2 = s32[] parameter(0)
      %async-start = ((s32[]), s32[], s32[]) custom-call-start(s32[] %prev.2), async_execution_thread="parallel_thread", custom_call_target="async_add"
      %async-done = s32[] custom-call-done(((s32[]), s32[], s32[]) %async-start), async_execution_thread="parallel_thread", custom_call_target="async_add"
      ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %async-done), direction=GT
    }

    ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
      %constant.2 = s32[] constant(0)
      ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3
    }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));
  // TF_ASSERT_OK (unlike ASSERT_TRUE(status.ok())) prints the verifier's error
  // message when this regresses.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
2515 
// Inside the fused computation, the dynamic-update-slice result carries the
// S(3) memory-space annotation while its updated operand %c (a copy) does not.
// The module is expected to pass the layout-sensitive fusion verifier.
TEST_F(HloVerifierTestLayoutFusion, DynamicUpdateSliceWithMemorySpace) {
  const char* const hlo_string = R"(
HloModule fusion, is_scheduled=true

fused_computation {
  %parameter.0 = bf16[1,8,1,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)} parameter(0)
  %parameter.1 = bf16[1,8,6,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)} parameter(1)
  %c = bf16[1,8,6,8,320]{4,0,3,2,1:T(2,128)(2,1)} copy(parameter.1)
  %constant.1 = s32[] constant(0)
  ROOT %dynamic-update-slice.1 = bf16[1,8,6,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)}
    dynamic-update-slice(%c, %parameter.0, %constant.1, %constant.1,
    %constant.1, %constant.1, %constant.1)
}

ENTRY entry (parameter.0: bf16[1,8,1,8,320], parameter.1: bf16[1,8,6,8,320]) -> bf16[1,8,6,8,320]{
  %p0 = bf16[1,8,1,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)} parameter(0)
  %p1 = bf16[1,8,6,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)} parameter(1)
  ROOT out = bf16[1,8,6,8,320]{4,0,3,2,1:T(2,128)(2,1)S(3)} fusion(p0, p1), kind=kLoop, calls=fused_computation
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnUnverifiedModule(hlo_string));

  // TF_ASSERT_OK (unlike ASSERT_TRUE(status.ok())) prints the verifier's error
  // message when this regresses.
  TF_ASSERT_OK(verifier().Run(module.get()).status());
}
2541 
2542 }  // namespace
2543 }  // namespace xla
2544