• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_parser.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "absl/memory/memory.h"
22 #include "absl/strings/match.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
28 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
31 #include "tensorflow/compiler/xla/window_util.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35 
36 namespace xla {
37 namespace {
38 
39 namespace m = ::xla::match;
40 using absl::string_view;
41 
// One round-trip parser test case. The list built below checks that the text
// parses into an HloModule and that re-stringifying the module reproduces
// the original text.
struct TestData {
  // Identifier for the case; TestDataToString returns it verbatim.
  string test_name;
  // HLO module text fed to the parser.
  string module_string;
  // Defaults to true. NOTE(review): presumably gates running the HLO verifier
  // on the parsed module; the consumer is not visible here.
  bool enable_verification{true};
};
47 
TestDataToString(const::testing::TestParamInfo<TestData> & data)48 string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
49   return data.param.test_name;
50 }
51 
52 // For each string below, we check that:
53 //  - we parse it to an HloModule successfully, and
54 //  - the stringification of the resulting HloModule is equal to our original
55 //    string.
CreateTestCases()56 std::vector<TestData> CreateTestCases() {
57   // clang-format off
58   return std::vector<TestData>({
59 // ax + y
60 {
61 "AxpyParam",
62 R"(HloModule axpy_module
63 
64 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
65   %alpha = f32[] parameter(0)
66   %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
67   %x = f32[2,4]{1,0} parameter(1)
68   %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
69   %y = f32[2,4]{1,0} parameter(2)
70   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
71 }
72 
73 )"
74 },
75 // parameter replication
76 {
77 "ParamReplication",
78 R"(HloModule param_replication_module
79 
80 ENTRY %param_replication (a: f32[], b: (f32[2,4], (f32[2,4]))) -> (f32[], (f32[2,4], (f32[2,4]))) {
81   %a = f32[] parameter(0), parameter_replication={true}
82   %b = (f32[2,4]{1,0}, (f32[2,4]{1,0})) parameter(1), parameter_replication={false,true}
83   ROOT %tuple = (f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0}))) tuple(f32[] %a, (f32[2,4]{1,0}, (f32[2,4]{1,0})) %b)
84 }
85 
86 )"
87 },
88 // pred constant
89 {
90 "ConstantPred",
91 R"(HloModule constant_pred_module
92 
93 ENTRY %constant_pred () -> pred[] {
94   ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar"
95 }
96 
97 )"
98 },
99 // pred array constant
100 {
101 "ConstantPredArray",
102 R"(HloModule module
103 
104 ENTRY %constant_pred_array () -> pred[2,3] {
105   ROOT %constant = pred[2,3]{1,0} constant({ { 0, 1, 0 }, { 1, 0, 1 } })
106 }
107 
108 )"
109 },
110 
111 // s32 constant
112 {
113 "ConstantS32",
114 R"(HloModule constant_s32_module
115 
116 ENTRY %constant_s32 () -> s32[] {
117   ROOT %constant = s32[] constant(-42)
118 }
119 
120 )"
121 },
122 // f32 constant, but the value is not a decimal and there is a backend
123 // configuration
124 {
125 "ConstantF32",
126 R"(HloModule ConstantF32_module
127 
128 ENTRY %ConstantF32.v4 () -> f32[] {
129   ROOT %constant = f32[] constant(42), backend_config="this is a configuration"
130 }
131 
132 )"
133 },
134 // f32 constant, rank 1 empty array.
135 {
136 "ConstantF32R1Empty",
137 R"(HloModule ConstantF32Empty_module
138 
139 ENTRY %ConstantF32Empty.v4 () -> f32[0] {
140   ROOT %constant = f32[0]{0} constant({})
141 }
142 
143 )"
144 },
145 // f32 constant, rank 4 empty array.
146 {
147 "ConstantF32R4Empty",
148 R"(HloModule ConstantF32R4Empty_module
149 
150 ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] {
151   ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant({ { /*i0=0*/ }, { /*i0=1*/ } })
152 }
153 
154 )"
155 },
156 // constant 4D
157 {
158 "Constant4D",
159 R"(HloModule Small_3x2x1x1_module
160 
161 ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] {
162   ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
163 }
164 
165 )"
166 },
167 // non-finite constants: nan, inf, -inf
168 {
169 "ConstantNonFinite",
170 R"(HloModule IsFiniteR1F32s_module
171 
172 ENTRY %IsFiniteR1F32s.v2 () -> pred[6] {
173   %constant = f32[6]{0} constant({nan, 7, nan, -1, inf, -inf})
174   ROOT %is-finite = pred[6]{0} is-finite(f32[6]{0} %constant)
175 }
176 
177 )"
178 },
179 // constant f16
180 {
181 "ConstantF16",
182 R"(HloModule ConstantF16_module
183 
184 ENTRY %ConstantF16.v4 () -> f16[] {
185   ROOT %constant = f16[] constant(500)
186 }
187 
188 )"
189 },
190 // bf16
191 {
192 "BF16",
193 R"(HloModule BF16
194 
195 ENTRY %BF16.v4 () -> bf16[] {
196   ROOT %constant = bf16[] constant(500)
197 }
198 
199 )"
200 },
201 // constant + constant
202 {
203 "AddConstants",
204 R"(HloModule add_constants_module
205 
206 ENTRY %add_constants () -> f32[] {
207   %constant = f32[] constant(3.14)
208   ROOT %add = f32[] add(f32[] %constant, f32[] %constant)
209 }
210 
211 )"
212 },
213 // tuple constant
214 {
215 "TupleConstant",
216 R"(HloModule TupleConstant_module
217 
218 ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) {
219   ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant(( { {1}, {2} }, {2, 42} ))
220 }
221 
222 )"
223 },
224 // v1 > v2 ? v1 : v2
225 {
226 "SelectR1F32",
227 R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module
228 
229 ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
230   %v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
231   %v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
232   %greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, sharding={replicated}
233   ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
234 }
235 
236 )"
237 },
238 // empty tuple
239 {
240 "EmptyTupleCreate",
241 R"(HloModule EmptyTupleCreate_module
242 
243 ENTRY %EmptyTupleCreate.v1 () -> () {
244   ROOT %tuple = () tuple()
245 }
246 
247 )"
248 },
249 // tuple
250 {
251 "TupleCreate",
252 R"(HloModule TupleCreate_module
253 
254 ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
255   %v1 = f32[] parameter(0)
256   %v2 = f32[3]{0} parameter(1)
257   %v3 = f32[2,3]{1,0} parameter(2)
258   ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3)
259 }
260 
261 )"
262 },
263 {
264 "ShardedTupleCreate",
265 R"(HloModule ShardedTupleCreate_module
266 
267 ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
268   %v1 = f32[] parameter(0)
269   %v2 = f32[3]{0} parameter(1)
270   %v3 = f32[2,3]{1,0} parameter(2)
271   ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}}
272 }
273 
274 )"
275 },
276 {
277 "DomainParsing",
278 R"(HloModule DomainParsing_module
279 
280 ENTRY %DomainParsing (v1: f32[]) -> f32[] {
281   %v1 = f32[] parameter(0)
282   ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
283 }
284 
285 )"
286 },
287 // int32 result = 0;
288 // while (result < 5) { result = result + 1; }
289 {
290 "WhileWithScalarS32Result",
291 R"(HloModule WhileWithScalarS32Result_module
292 
293 %body.v3 (prev.1: s32[]) -> s32[] {
294   %constant = s32[] constant(1)
295   %prev.1 = s32[] parameter(0)
296   ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1)
297 }
298 
299 %condition.v3 (prev.2: s32[]) -> pred[] {
300   %constant.1 = s32[] constant(5)
301   %prev.2 = s32[] parameter(0)
302   ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT
303 }
304 
305 ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
306   %constant.2 = s32[] constant(0)
307   ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3
308 }
309 
310 )"
311 },
312 // copy-start and copy-done
313 {
314 "CopyStartAndCopyDone",
315 
316 R"(HloModule CopyStartAndCopyDone_module
317 
318 ENTRY %CopyStartAndCopyDone (v1: f32[], v2: f32[2,3]) -> (f32[], f32[2,3]) {
319   %v1 = f32[] parameter(0)
320   %copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1)
321   %copy-done.1 = f32[] copy-done((f32[], f32[], u32[]) %copy-start.1)
322   %v2 = f32[2,3]{1,0:S(1)} parameter(1)
323   %copy-start.2 = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2)
324   %copy-done.2 = f32[2,3]{1,0:S(2)} copy-done((f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) %copy-start.2)
325   ROOT %tuple = (f32[], f32[2,3]{1,0:S(2)}) tuple(f32[] %copy-done.1, f32[2,3]{1,0:S(2)} %copy-done.2)
326 }
327 
328 )"
329 },
330 // send and recv
331 {
332 "SendRecv",
333 R"(HloModule TwoSendRecvBothWayRecvFist_module
334 
335 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
336   %token0 = token[] after-all()
337   %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1}
338   ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
339   %constant = f32[] constant(2.1), sharding={maximal device=0}
340   %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
341   %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}
342 }
343 
344 )"
345 },
346 {
347 "SendRecvWithHostTransfer",
348 R"(HloModule HostTransferSendRecv_module
349 
350 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
351   %token0 = token[] after-all()
352   %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, is_host_transfer=true
353   ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true
354   %constant = f32[] constant(2.1), sharding={maximal device=0}
355   %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, is_host_transfer=true
356   %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true
357 }
358 
359 )"
360 },
361 // get-tuple-element
362 {
363 "GetTupleElement",
364 R"(HloModule GetTupleElement_module
365 
366 ENTRY %GetTupleElement.v4 () -> s32[2,3] {
367   %constant = f32[3]{0} constant({1, 2, 3})
368   %constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } })
369   %tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1)
370   ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0}
371 }
372 
373 )"
374 },
375 // call
376 {
377 "Call",
378 R"(HloModule CallR0F32IdentityScalar_module
379 
380 %Identity.v1 (x: f32[]) -> f32[] {
381   ROOT %x = f32[] parameter(0)
382 }
383 
384 ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] {
385   %constant = f32[] constant(42)
386   ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1
387 }
388 
389 )"
390 },
391 // CustomCall with backend_config.
392 {
393 "CustomCallWithOpaque",
394 R"(HloModule custom_call
395 
396 ENTRY %CustomCall () -> f32[1,2,3] {
397   %constant = f32[1]{0} constant({12345})
398   ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", backend_config="this string is opaque"
399 }
400 
401 )"
402 },
403 // reduce window
404 {
405 "ReduceWindow",
406 R"(HloModule R4UnitWindow_module
407 
408 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
409   %lhs = f32[] parameter(0)
410   %rhs = f32[] parameter(1)
411   ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
412 }
413 
414 ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] {
415   %operand = f32[13,12,8,15]{0,3,2,1} parameter(0)
416   %constant = f32[] constant(0)
417   ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3
418 }
419 
420 )"
421 },
422 // reduce window on scalar
423 {
424 "ReduceWindowScalar",
425 R"(HloModule reduce_window_scalar
426 
427 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
428   %lhs = f32[] parameter(0)
429   %rhs = f32[] parameter(1)
430   ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
431 }
432 
433 ENTRY %R4UnitWindowScalar () -> f32[] {
434   %constant = f32[] constant(42)
435   %constant.1 = f32[] constant(1)
436   ROOT %reduce-window = f32[] reduce-window(f32[] %constant, f32[] %constant.1), to_apply=%add_F32.v3
437 }
438 
439 )"
440 },
441 // convolution
442 {
443 "Convolution",
444 R"(HloModule Convolve1D1Window_0_module
445 
446 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
447   %input = f32[1,2,1]{2,1,0} parameter(0)
448   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
449   %filter = f32[1,1,1]{2,1,0} parameter(1)
450   ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}
451 }
452 
453 )"
454 },
455 // convolution rank 2
456 {
457 "ConvolutionR2",
458 R"(HloModule ConvolveR2_module
459 
460 ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[2,2]) -> f32[1,2] {
461   %input = f32[1,2]{1,0} parameter(0)
462   %filter = f32[2,2]{1,0} parameter(1)
463   ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[2,2]{1,0} %filter), dim_labels=bf_io->bf
464 }
465 
466 )"
467 },
468 // convolution backward
469 {
470 "ConvolutionBackward",
471 R"(HloModule ConvolveBackward_module
472 
473 ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] {
474   %input = f32[128,7,7,512]{0,3,2,1} parameter(0)
475   %filter = f32[3,3,512,512]{3,2,1,0} parameter(1)
476   ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f
477 }
478 
479 )"
480 },
481 // reverse(constant)
482 {
483 "Reverse4D",
484 R"(HloModule Reverse4DFloatArrayOnDim01_module
485 
486 ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] {
487   %constant = f32[4,3,2,1]{0,1,2,3} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } })
488   ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1}
489 }
490 
491 )"
492 },
493 // concat
494 {
495 "Concat",
496 R"(HloModule Concat2x3With2x5_module
497 
498 ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] {
499   %constant = f32[2,3]{1,0} constant({ { 0, 1, 2 }, { 1000, 1001, 1002 } })
500   %constant.1 = f32[2,5]{1,0} constant({ { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } })
501   ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1}
502 }
503 
504 )"
505 },
506 // select and scatter
507 {
508 "SelectAndScatter",
509 R"(HloModule R4F32OverlapSmall_module
510 
511 %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
512   %lhs = f32[] parameter(0)
513   %rhs = f32[] parameter(1)
514   ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
515 }
516 
517 %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
518   %lhs.1 = f32[] parameter(0)
519   %rhs.1 = f32[] parameter(1)
520   ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
521 }
522 
523 ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] {
524   %constant = f32[4,5,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } })
525   %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } })
526   %constant.2 = f32[] constant(0)
527   ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3
528 }
529 
530 )"
531 },
532 // select and scatter on scalar
533 {
534 "SelectAndScatterScalar",
535 R"(HloModule select_and_scatter_scalar
536 
537 %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
538   %lhs = f32[] parameter(0)
539   %rhs = f32[] parameter(1)
540   ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
541 }
542 
543 %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
544   %lhs.1 = f32[] parameter(0)
545   %rhs.1 = f32[] parameter(1)
546   ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
547 }
548 
549 ENTRY %SelectAndScatterScalar () -> f32[] {
550   %constant = f32[] constant(42)
551   %constant.1 = f32[] constant(1)
552   %constant.2 = f32[] constant(2)
553   ROOT %select-and-scatter = f32[] select-and-scatter(f32[] %constant, f32[] %constant.1, f32[] %constant.2), select=%ge_F32.v3, scatter=%add_F32.v3
554 }
555 
556 )"
557 },
558 // slice
559 {
560 "Slice",
561 R"(HloModule slice_module
562 
563 ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
564   %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
565   ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3:1], [0:3:1], [0:4:2], [0:4:1]}
566 }
567 
568 )"
569 },
570 // slice, no stride
571 {
572 "SliceNoStride",
573 R"(HloModule Slice3x3x3_To_1x3x3_F32_module
574 
575 ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] {
576   %constant = f32[3,3,3]{2,1,0} constant({ { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } })
577   ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]}
578 }
579 
580 )"
581 },
582 // slice R0
583 {
584 "SliceR0",
585 R"(HloModule SliceR0_module
586 
587 ENTRY %SliceR0.v2 () -> s32[] {
588   %constant = s32[] constant(1)
589   ROOT %slice = s32[] slice(s32[] %constant), slice={}
590 }
591 
592 )"
593 },
594 // transpose
595 {
596 "Transpose",
597 R"(HloModule Transpose_module
598 
599 ENTRY %Transpose.v2 () -> s32[1,2,3] {
600   %constant = s32[1,2,3]{2,1,0} constant({ { { 1, 2, 3 }, { 4, 5, 6 } } })
601   ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2}
602 }
603 
604 )"
605 },
606 {
607 "TransposeC128",
608 R"(HloModule TransposeC128_module
609 
610 ENTRY %Transpose.v3 (input: c128[1,2,3]) -> c128[1,2,3] {
611   %input = c128[1,2,3]{2,1,0} parameter(0)
612   ROOT %transpose = c128[1,2,3]{2,1,0} transpose(c128[1,2,3]{2,1,0} %input), dimensions={0,1,2}
613 }
614 
615 )"
616 },
617 // Triangular solve
618 {
619 "TriangularSolve",
620 R"(HloModule TriangularSolve_module
621 
622 ENTRY %SimpleRightLowerNotranspose.4 (a.1: f32[4,4], b.2: f32[3,4]) -> f32[3,4] {
623   %a.1 = f32[4,4]{1,0} parameter(0)
624   %b.2 = f32[3,4]{1,0} parameter(1)
625   ROOT %triangular-solve.3 = f32[3,4]{1,0} triangular-solve(f32[4,4]{1,0} %a.1, f32[3,4]{1,0} %b.2), lower=true, transpose_a=NO_TRANSPOSE
626 }
627 
628 )"
629 },
630 // Dynamic slice
631 {
632 "DynamicSlice",
633 R"(HloModule DynamicSlice_module
634 
635 ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] {
636   %original_parameter = s32[2,2,258]{2,1,0} parameter(0)
637   %constant = s32[1]{0} constant({0})
638   %start_index = s32[1]{0} parameter(1)
639   %concatenate = s32[3]{0} concatenate(s32[1]{0} %constant, s32[1]{0} %constant, s32[1]{0} %start_index), dimensions={0}
640   ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258}
641 }
642 
643 )"
644 },
645 // Dynamic slice with scalar indices
646 {
647 "DynamicSliceScalarIndices",
648 R"(HloModule DynamicSlice_module
649 
650 ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
651   %original_parameter = s32[2,2,258]{2,1,0} parameter(0)
652   %constant = s32[] constant(0)
653   %start_index = s32[] parameter(1)
654   ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
655 }
656 
657 )"
658 },
659 // Dynamic update slice
660 {
661 "DynamicUpdateSlice",
662 R"(HloModule DynamicSlice_module
663 
664 ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] {
665   %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
666   %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
667   %start_indices = s32[4]{0} parameter(2)
668   ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices)
669 }
670 
671 )"
672 },
673 // Dynamic update slice with scalar indices
674 {
675 "DynamicUpdateSliceScalarIndex",
676 R"(HloModule DynamicUpdateSlice_module
677 
678 ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
679   %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
680   %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
681   %start_index.0 = s32[] parameter(2)
682   %start_index.1 = s32[] parameter(3)
683   %start_index.2 = s32[] parameter(4)
684   %start_index.3 = s32[] parameter(5)
685   ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3)
686 }
687 
688 )"
689 },
690 // batch norm training
691 {
692 "BatchNormTraining",
693 R"(HloModule BasicTraining_module
694 
695 ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) {
696   %constant = f32[2,2,1,2]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } })
697   %constant.1 = f32[2]{0} constant({2, 3})
698   %constant.2 = f32[2]{0} constant({1, 2})
699   ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3
700 }
701 
702 )"
703 },
704 // batch norm inference
705 {
706 "BatchNormInference",
707 R"(HloModule BatchNormInference_module
708 
709 ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] {
710   %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
711   %offset = f32[2]{0} parameter(1)
712   %scale = f32[2]{0} parameter(2)
713   %mean = f32[2]{0} parameter(3)
714   %variance = f32[2]{0} parameter(4)
715   ROOT %batch-norm-inference = f32[2,2,2,2]{3,2,1,0} batch-norm-inference(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance), epsilon=0.001, feature_index=0
716 }
717 
718 )"
719 },
720 // batch norm grad
721 {
722 "BatchNormGrad",
723 R"(HloModule BatchNormGrad_module
724 
725 ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) {
726   %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
727   %scale = f32[2]{0} parameter(1)
728   %mean = f32[2]{0} parameter(2)
729   %variance = f32[2]{0} parameter(3)
730   %grad_output = f32[2,2,2,2]{3,2,1,0} parameter(4)
731   ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0
732 }
733 
734 )"
735 },
736 // fft
737 {
738 "Fft",
739 R"(HloModule Fft_module
740 
741 ENTRY %Fft (input: c64[8,32]) -> c64[8,32] {
742   %input = c64[8,32]{1,0} parameter(0)
743   ROOT %fft = c64[8,32]{1,0} fft(c64[8,32]{1,0} %input), fft_type=FFT, fft_length={32}
744 }
745 
746 )"
747 },
748 // ifft
749 {
750 "Ifft2d",
751 R"(HloModule Ifft2d_module
752 
753 ENTRY %Ifft2d (input: c64[5,8,32]) -> c64[5,8,32] {
754   %input = c64[5,8,32]{2,1,0} parameter(0)
755   ROOT %fft = c64[5,8,32]{2,1,0} fft(c64[5,8,32]{2,1,0} %input), fft_type=IFFT, fft_length={8,32}
756 }
757 
758 )"
759 },
760 // rfft2d
761 {
762 "Rfft2d",
763 R"(HloModule Rfft2d_module
764 
765 ENTRY %Rfft2d (input: f32[5,64,32]) -> c64[5,64,17] {
766   %input = f32[5,64,32]{2,1,0} parameter(0)
767   ROOT %fft = c64[5,64,17]{2,1,0} fft(f32[5,64,32]{2,1,0} %input), fft_type=RFFT, fft_length={64,32}
768 }
769 
770 )"
771 },
772 // irfft3d
773 {
774 "Irfft3d",
775 R"(HloModule Irfft3d_module
776 
777 ENTRY %Irfft3d (input: c64[5,64,128,33]) -> f32[5,64,128,64] {
778   %input = c64[5,64,128,33]{3,2,1,0} parameter(0)
779   ROOT %fft = f32[5,64,128,64]{3,2,1,0} fft(c64[5,64,128,33]{3,2,1,0} %input), fft_type=IRFFT, fft_length={64,128,64}
780 }
781 
782 )"
783 },
784 // pad
785 {
786 "Pad",
787 R"(HloModule Pad1DS3Array_module
788 
789 ENTRY %Pad1DS3Array.v3 () -> f32[7] {
790   %constant = f32[3]{0} constant({1, 2, 3})
791   %constant.1 = f32[] constant(0.1)
792   ROOT %pad = f32[7]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1
793 }
794 
795 )"
796 },
797 // pad has interior
798 {
799 "PadHasInterior",
800 R"(HloModule PadHasInterior_module
801 
802 ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] {
803   %input = f32[1,25,7,7]{3,2,1,0} parameter(0)
804   %constant = f32[] constant(-5.123)
805   ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0
806 }
807 
808 )"
809 },
810 // Negative padding
811 {
812 "PadHasNegativePadding",
813 R"(HloModule PadHasNegativePadding_module
814 
815 ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,35] {
816   %input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0)
817   %constant = f32[] constant(-5.123)
818   ROOT %pad = f32[1,15,6,3,35]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3
819 }
820 
821 )"
822 },
823 // fusion
824 {
825 "Fusion",
826 R"(HloModule fusion_module
827 
828 %fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] {
829   %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
830   %constant.1.param_1 = f32[2]{0} parameter(1)
831   %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %constant.1.param_1), dimensions={1}
832   ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %constant.param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
833 }
834 
835 ENTRY %fusion.v3 () -> f32[3,2,1,1] {
836   %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
837   %constant.1 = f32[2]{0} constant({3.14, 4.25})
838   ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation
839 }
840 
841 )"
842 },
843 {
844 "Gather",
845 R"(HloModule StringifyGather
846 
847 ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
848   %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
849   %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
850   ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
851 }
852 
853 )"
854 },
855 {
856 "SortedGather",
857 R"(HloModule StringifyGather
858 
859 ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
860   %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
861   %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
862   ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}, indices_are_sorted=true
863 }
864 
865 )"
866 },
867 {
868 "Scatter",
869 R"(HloModule StringifyScatter
870 
871 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
872   %lhs = f32[] parameter(0)
873   %rhs = f32[] parameter(1)
874   ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
875 }
876 
877 ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
878   %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
879   %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
880   %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
881   ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32.v3
882 }
883 
884 )"
885 },
886 {
887 "SortedScatter",
888 R"(HloModule StringifySortedScatter
889 
890 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
891   %lhs = f32[] parameter(0)
892   %rhs = f32[] parameter(1)
893   ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
894 }
895 
896 ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
897   %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
898   %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
899   %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
900   ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, indices_are_sorted=true, to_apply=%add_F32.v3
901 }
902 
903 )"
904 },
905 {
906 "UniqueIndicesScatter",
907 R"(HloModule StringifyUniqueIndicesScatter
908 
909 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
910   %lhs = f32[] parameter(0)
911   %rhs = f32[] parameter(1)
912   ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
913 }
914 
915 ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
916   %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
917   %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
918   %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
919   ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, unique_indices=true, to_apply=%add_F32.v3
920 }
921 
922 )"
923 },
924 {
925   "ConstantUnsignedNoUnderflow",
926   R"(HloModule ConstantUnsignedNoUnderflow_module
927 
928 ENTRY %ConstantUnsignedNoUnderflow () -> u64[] {
929   ROOT %constant = u64[] constant(1)
930 }
931 
932 )"
933 },
934 
935 {
936   "ConstantUnsignedNoOverflow",
937   R"(HloModule ConstantUnsignedNoOverflow_module
938 
939 ENTRY %ConstantUnsignedNoOverflow () -> u64[] {
940   ROOT %constant = u64[] constant(9223372036854775807)
941 }
942 
943 )"
944 },
945 // CustomCallWithLayoutConstraints
946 {
947 "CustomCallWithLayoutConstraints",
948 R"(HloModule CustomCallWithLayoutConstraints
949 
950 ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
951   %p0 = f32[42,2,3]{0,1,2} parameter(0)
952   %p1 = f32[123,4]{0,1} parameter(1)
953   ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}}
954 }
955 
956 )"
957 },
958 // CustomCallWithLayoutConstraintsNoOperands
959 {
960 "CustomCallWithLayoutConstraintsNoOperands",
961 R"(HloModule CustomCallWithLayoutConstraintsNoOperands
962 
963 ENTRY %CustomCallWithLayoutConstraints () -> f32[1,2,3] {
964   ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
965 }
966 
967 )"
968 },
969 // CustomCallWithLayoutConstraintsTupleShapes
970 {
971 "CustomCallWithLayoutConstraintsTupleShapes",
972 R"(HloModule CustomCallWithLayoutConstraintsTupleShapes
973 
974 ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) {
975   %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
976   %p1 = f32[123,4]{0,1} parameter(1)
977   ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}}
978 }
979 
980 )"
981 },
982 // CustomCallWithHasSideEffect
983 {
984 "CustomCallWithHasSideEffect",
985 R"(HloModule CustomCallWithHasSideEffect
986 
987 ENTRY %CustomCallWithHasSideEffect (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) {
988   %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
989   %p1 = f32[123,4]{0,1} parameter(1)
990   ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", custom_call_has_side_effect=true
991 }
992 
993 )"
994 },
995 // Parse c64 literal
996 {
997 "ParseC64Literal",
998 R"(HloModule ParseC64Literal
999 
1000 ENTRY %ParseC64Literal () -> c64[2] {
1001   ROOT %c = c64[2]{0} constant({(1, 2), (-inf, nan)})
1002 }
1003 
1004 )"
1005 },
1006 // Parse c128 literal
1007 {
1008 "ParseC128Literal",
1009 R"(HloModule ParseC128Literal
1010 
1011 ENTRY %ParseC128Literal () -> c128[2] {
1012   ROOT %c = c128[2]{0} constant({(1, 2), (-inf, nan)})
1013 }
1014 
1015 )"
1016 },
1017 // Indexed Conditional
1018 {
1019 "IndexedConditional",
1020 R"(HloModule indexed_conditional
1021 
1022 %Negate (x: f32[]) -> f32[] {
1023   %x = f32[] parameter(0)
1024   ROOT %negate = f32[] negate(f32[] %x)
1025 }
1026 
1027 %Identity (y: f32[]) -> f32[] {
1028   %y = f32[] parameter(0)
1029   ROOT %copy = f32[] copy(f32[] %y)
1030 }
1031 
1032 %Floor (z: f32[]) -> f32[] {
1033   %z = f32[] parameter(0)
1034   ROOT %floor = f32[] floor(f32[] %z)
1035 }
1036 
1037 ENTRY %Parameters1.v4 () -> f32[] {
1038   %constant = s32[] constant(1)
1039   %constant.1 = f32[] constant(56)
1040   %constant.2 = f32[] constant(12)
1041   %constant.3 = f32[] constant(13)
1042   ROOT %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor}
1043 }
1044 
1045 )"
1046 },
1047 // rng-get-and-update-state
1048 {
1049 "RngGetAndUpdateState",
1050 R"(HloModule rng_get_and_update_state
1051 
1052 ENTRY %RngGetAndUpdateState () -> u64[2] {
1053   ROOT %rng-get-and-update-state = u64[2]{0} rng-get-and-update-state(), delta=4096
1054 }
1055 
1056 )"
1057 },
1058   });
1059   // clang-format on
1060 }
1061 
// Round-trip cases written in the short, parsable text form. In this form:
// - operands are referred to by bare instruction name, with no '%' sigil and
//   no inline operand shape;
// - computation headers carry no parameter/result signature.
// These cases are instantiated for HloParserTestShort and
// HloParserTestShortProto. Those fixtures parse each module_string and expect
// ToString(HloPrintOptions::ShortParsable()) to reproduce it exactly. Every raw
// string below, including its trailing blank line, must therefore match the
// printer's output byte for byte.
std::vector<TestData> CreateShortTestCases() {
  // clang-format off
  return std::vector<TestData>({
// map
{
"Map",
R"(HloModule MapBinaryAdder_module

add_F32.v3 {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY MapBinaryAdder.v3 {
  param0 = f32[4]{0} parameter(0)
  param1 = f32[4]{0} parameter(1)
  ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3
}

)"
},
// reduce
{
"Reduce",
R"(HloModule ReduceR3ToR2_module

add_F32.v3 {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY ReduceR3ToR2.v3 {
  input = f32[8,16,256]{2,1,0} parameter(0)
  constant = f32[] constant(0)
  ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
}

)"
},
// tuple reduce
{
"TupleReduce",
R"(HloModule TupleReduce

max_argmax {
  value = f32[] parameter(2)
  prev_max = f32[] parameter(0)
  is_next_larger = pred[] compare(value, prev_max), direction=GE
  max = f32[] select(is_next_larger, value, prev_max)
  index = s32[] parameter(3)
  prev_argmax = s32[] parameter(1)
  argmax = s32[] select(is_next_larger, index, prev_argmax)
  ROOT pair = (f32[], s32[]) tuple(max, argmax)
}

ENTRY reduce_entry {
  values = f32[1024]{0} parameter(0)
  indices = s32[1024]{0} parameter(1)
  init_value = f32[] constant(-inf)
  init_index = s32[] constant(-1)
  ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax
}

)"
},
// infeed/outfeed
{
"InfeedOutfeed",
R"(HloModule outfeed_module

ENTRY InfeedToOutfeed {
  token0 = token[] after-all()
  infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
  infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
  outfeed = token[] outfeed(infeed.data, token0)
  ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0)
  infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
  infeed.1.token = token[] get-tuple-element(infeed.1), index=1
  outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token)
}

)"
},
// Rng
{
"Rng",
R"(HloModule rng_module

ENTRY Rng {
  constant = f32[] constant(0)
  constant.1 = f32[] constant(1)
  ROOT rng = f32[8]{0} rng(constant, constant.1), distribution=rng_uniform
}

)"
},
// Reduce precision
{
"ReducePrecision",
R"(HloModule reduce_precision

ENTRY ReducePrecision {
  constant = f32[1]{0} constant({3.14159})
  ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10
}

)"
},
// Sort (Key)
{
"SortKey",
R"(HloModule sort

compare {
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY Sort {
  x = f32[1024]{0} parameter(0)
  ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, to_apply=compare
}

)"
},
// Sort (Key, Value)
{
"SortKeyValue",
R"(HloModule sort

compare {
  p.1.lhs = s32[] parameter(2)
  p.1.rhs = s32[] parameter(3)
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY Sort {
  keys = f32[1024]{0} parameter(0)
  values = s32[1024]{0} parameter(1)
  ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
}

)"
},
// R2 Sort (Key)
{
"SortKeyR2",
R"(HloModule sort

compare {
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY Sort {
  x = f32[1024,16]{0,1} parameter(0)
  ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0}, to_apply=compare
}

)"
},
// R2 Sort (Key, Value)
{
"SortKeyValueR2",
R"(HloModule sort

compare {
  p.1.lhs = s32[] parameter(2)
  p.1.rhs = s32[] parameter(3)
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY Sort {
  keys = f32[1024,16]{0,1} parameter(0)
  values = s32[1024,16]{0,1} parameter(1)
  ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0}, to_apply=compare
}

)"
},
// Sort (Key, Value, Value, Value)
{
"SortManyValues",
R"(HloModule sort

compare {
  p.1.lhs = s32[] parameter(2)
  p.1.rhs = s32[] parameter(3)
  p.2.lhs = u32[] parameter(4)
  p.2.rhs = u32[] parameter(5)
  p.3.lhs = f32[] parameter(6)
  p.3.rhs = f32[] parameter(7)
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY Sort {
  keys = f32[1024,16]{0,1} parameter(0)
  values.0 = s32[1024,16]{0,1} parameter(1)
  values.1 = u32[1024,16]{0,1} parameter(2)
  values.2 = f32[1024,16]{0,1} parameter(3)
  ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare
}

)"
},
// Sort (Key) is_stable=true
{
"SortKeyStable",
R"(HloModule sort

compare {
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY Sort {
  x = f32[1024]{0} parameter(0)
  ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare
}

)"
},
// Indexed Conditional
{
"IndexedConditional",
R"(HloModule indexed_conditional

Negate {
  x = f32[] parameter(0)
  ROOT negate = f32[] negate(x)
}

Identity {
  y = f32[] parameter(0)
  ROOT copy = f32[] copy(y)
}

Floor {
  z = f32[] parameter(0)
  ROOT floor = f32[] floor(z)
}

ENTRY Parameters1.v4 {
  constant = s32[] constant(1)
  constant.1 = f32[] constant(56)
  constant.2 = f32[] constant(12)
  constant.3 = f32[] constant(13)
  ROOT conditional = f32[] conditional(constant, constant.1, constant.2, constant.3), branch_computations={Negate, Identity, Floor}
}

)"
},
// Predicated Conditional
{
"PredicatedConditional",
R"(HloModule pred_conditional

Negate {
  x = f32[] parameter(0)
  ROOT negate = f32[] negate(x)
}

Identity {
  y = f32[] parameter(0)
  ROOT copy = f32[] copy(y)
}

ENTRY Parameters1.v4 {
  constant = pred[] constant(true)
  constant.1 = f32[] constant(56)
  constant.2 = f32[] constant(12)
  ROOT conditional = f32[] conditional(constant, constant.1, constant.2), true_computation=Negate, false_computation=Identity
}

)"
},
// CustomCall
{
"CustomCall",
R"(HloModule custom_call

ENTRY CustomCall {
  constant = f32[1]{0} constant({12345})
  ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar"
}

)"
},
// Variables with non-default names
{
"NonDefaultNames",
R"(HloModule add_constants_module

ENTRY add_constants {
  foo = f32[] constant(3.14)
  ROOT bar = f32[] add(foo, foo)
}

)"
},
// dot with batch and contracting dimensions
{
"Dot",
R"(HloModule dot

ENTRY dot {
  a = f32[2,10]{1,0} parameter(0)
  b = f32[10,2]{1,0} parameter(1)
  ROOT dot = f32[2]{0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={1}, rhs_contracting_dims={0}
}

)"
},
// gather
{
"gather",
R"(HloModule gather

ENTRY Gather {
  input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
  start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
  ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
}

)"
},
// all-reduce
{
"AllReduce",
R"(HloModule CRS

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY CRS {
  input = f32[8]{0} parameter(0)
  ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, to_apply=add
}

)"
},
// all-reduce with subgroups
{
"AllReduceWithSubgroups",
R"(HloModule CRS_Subgroups

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY AllReduceWithSubgroups {
  input = f32[128,32]{0,1} parameter(0)
  ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, to_apply=add
}

)"
},
// all-reduce with constrained layout
{
"AllReduceWithLayout",
R"(HloModule CRS

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY CRS {
  input = f32[8]{0} parameter(0)
  ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, constrain_layout=true, to_apply=add
}

)"
},
// Two all-reduces that carry the same channel_id (channel_id=1).
{
"AllReduceAllReduce",
R"(HloModule CRS

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY CRS {
  input = f32[8]{0} parameter(0)
  crs.1 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
  ROOT crs.0 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
}

)"
},
// all-to-all
{
"AllToAll",
R"(HloModule AllToAll

ENTRY AllToAll {
  input = f32[128,32]{0,1} parameter(0)
  ROOT a2a = (f32[128,32]{0,1}) all-to-all(input), replica_groups={}
}

)"
},
// all-to-all with subgroups
{
"AllToAllWithSubgroups",
R"(HloModule AllToAllWithSubgroups

ENTRY AllToAllWithSubgroups {
  p0 = f32[128,32]{0,1} parameter(0)
  p1 = f32[128,32]{0,1} parameter(1)
  ROOT a2a = (f32[128,32]{0,1}, f32[128,32]{0,1}) all-to-all(p0, p1), replica_groups={{1,2},{3,0}}
}

)"
},
// collective-permute
{
"CollectivePermute",
R"(HloModule CollectivePermute

ENTRY CollectivePermute {
  input = f32[128,32]{0,1} parameter(0)
  ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
}

)"
},
// replica-id
{
"ReplicaId",
R"(HloModule replica-id

ENTRY Replica-id {
  ROOT replica-id = u32[] replica-id()
}

)"
},
// partition-id
{
"PartitionId",
R"(HloModule partition-id

ENTRY PartitionId {
  ROOT id = u32[] partition-id()
}

)"
},
// Iota
{
"Iota",
R"(HloModule iota

ENTRY Iota {
  ROOT iota = f32[100]{0} iota(), iota_dimension=0
}

)"
},
// custom-call with window, dim_labels and feature_group_count
{
"CustomCallWithWindowAndDimLabelsAndFeatureGroupCount",
R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount

ENTRY Computation {
  ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target"
}

)"
    },
// is_scheduled=true attribute
{
"ScheduledModule",
R"(HloModule scheduled_module, is_scheduled=true

compare {
  p.1.lhs = s32[] parameter(2)
  p.1.rhs = s32[] parameter(3)
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lhs = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY Sort {
  keys = f32[1024]{0} parameter(0)
  values = s32[1024]{0} parameter(1)
  ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
}

)"
    },
// AfterAll with multiple operands
{
"AfterAllWithMultipleOperands",
R"(HloModule AfterAllWithMultipleOperands

ENTRY AfterAllWithMultipleOperands {
  p0 = f32[] parameter(0)
  token0 = token[] after-all()
  token1 = token[] after-all()
  ROOT after-all = token[] after-all(p0, token0, token1)
}

)"
},
// AddDependency
// A dependency chain is created from 'neg' to 'exp' using tokens.
{
"AddDependency",
R"(HloModule AddDependency

ENTRY AddDependency {
  p = f32[] parameter(0)
  neg = f32[] negate(p)
  token0 = token[] after-all(neg)
  p_after_token = f32[] add-dependency(p, token0)
  exp = f32[] exponential(p_after_token)
  ROOT sum = f32[] add(neg, exp)
}

)"
},

// A module containing constants equal to the min/max values of various data
// types.
{
"MinMaxValues",
R"(HloModule MinMaxValues

ENTRY MinMaxValues {
  x.s8 = s8[2]{0} constant({-128, 127})
  x.s16 = s16[2]{0} constant({-32768, 32767})
  x.s32 = s32[2]{0} constant({-2147483648, 2147483647})
  x.u8 = u8[2]{0} constant({0, 255})
  x.u16 = u16[2]{0} constant({0, 65535})
  x.u32 = u32[2]{0} constant({0, 4294967295})
  x.f16 = f16[2]{0} constant({-65504, 65504})
  x.bf16 = bf16[2]{0} constant({-3.39e+38, 3.39e+38})
  x.f32 = f32[2]{0} constant({-3.40282e+38, 3.40282e+38})
  x.f64 = f64[2]{0} constant({-1.79769e+308, 1.79769e+308})
  x.c64 = c64[2]{0} constant({(-3.40282e+38, 3.40282e+38), (3.40282e+38, -3.40282e+38)})
  ROOT c.c128 = c128[2]{0} constant({(-1.79769e+308, 1.79769e+308), (1.79769e+308, -1.79769e+308)})
}

)"
},

// Bitcast-convert usage
{
"BitcastConvert",
R"(HloModule BitcastConvert

ENTRY BitcastConvertUsage {
  p = f32[100]{0} parameter(0)
  ROOT out = u32[100]{0} bitcast-convert(p)
}

)"
},
// outer_dimension_partitions attribute on a parameter
{
"OuterDimensionPartitions",
R"(HloModule OuterDimensionPartitions

ENTRY Test {
  ROOT foo = f32[100]{0} parameter(0), outer_dimension_partitions={0,10,20}
}

)"
},
});
  // clang-format on
}
1653 
1654 // The test class for those tests defined above which round-trip through the
1655 // parser and ToString is templatized on two bool parameters:
1656 //
1657 //  short_form : used for the "short" test cases which use the ShortParsable
1658 //    output form.
1659 //  proto_round_trip : whether the module should also be round-tripped through
1660 //    HloProto form. This provides much better coverage for the proto
1661 //    serialization/deserialization.
1662 //
1663 // The proto_round_trip=true case also technically covers the Parser->ToString
1664 // roundtrip as well, but separating out the Parser->ToString roundtrip as its
1665 // own test provides better isolation and could conceivably catch weirdo bugs
1666 // which are hidden by interaction between the textual and proto roundtripping.
1667 template <bool short_form, bool proto_round_trip>
1668 class HloParameterizedParserTest
1669     : public ::testing::Test,
1670       public ::testing::WithParamInterface<TestData> {
1671  protected:
1672   // Expects "ToString(ParseHloModule(string)) == string", that is, parses the
1673   // string, asserts that it succeeded, stringifies the parsed module, and
1674   // checks that it equals the original string.
1675   void ExpectEqual() {
1676     std::unique_ptr<HloModule> module;
1677     const string& original = GetParam().module_string;
1678     if (GetParam().enable_verification) {
1679       auto verified_module = absl::make_unique<VerifiedHloModule>(
1680           GetParam().test_name, HloModuleConfig(),
1681           /*verifier_layout_sensitive=*/false,
1682           /*allow_mixed_precision_in_hlo_verifier=*/true,
1683           ShapeUtil::ByteSizeOfElements);
1684       TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original));
1685       module = std::move(verified_module);
1686     } else {
1687       TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnUnverifiedModule(original));
1688     }
1689     if (proto_round_trip) {
1690       TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
1691                                           module->ToProto(), module->config()));
1692     }
1693     if (short_form) {
1694       EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable()));
1695     } else {
1696       EXPECT_EQ(
1697           original,
1698           module->ToString(HloPrintOptions().set_print_large_constants(true)));
1699     }
1700   }
1701 };
1702 
1703 // These using shenanigans are required because the TEST_P macro doesn't like
1704 // template instantiations which contain commas.
1705 using HloParserTestLong = HloParameterizedParserTest<false, false>;
1706 using HloParserTestLongProto = HloParameterizedParserTest<false, true>;
1707 using HloParserTestShort = HloParameterizedParserTest<true, false>;
1708 using HloParserTestShortProto = HloParameterizedParserTest<true, true>;
1709 
1710 TEST_P(HloParserTestLong, Run) { ExpectEqual(); }
1711 TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); }
1712 TEST_P(HloParserTestShort, Run) { ExpectEqual(); }
1713 TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); }
1714 
1715 INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestLong,
1716                          ::testing::ValuesIn(CreateTestCases()),
1717                          TestDataToString);
1718 INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
1719                          HloParserTestLongProto,
1720                          ::testing::ValuesIn(CreateTestCases()),
1721                          TestDataToString);
1722 INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestShort,
1723                          ::testing::ValuesIn(CreateShortTestCases()),
1724                          TestDataToString);
1725 INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
1726                          HloParserTestShortProto,
1727                          ::testing::ValuesIn(CreateShortTestCases()),
1728                          TestDataToString);
1729 
1730 class HloParserTest : public ::testing::Test {
1731  protected:
1732   static void ExpectHasSubstr(string_view s, string_view expected) {
1733     EXPECT_TRUE(absl::StrContains(s, expected))
1734         << "'" << s << "' does not contain '" << expected << "'";
1735   }
1736   StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
1737       absl::string_view hlo_text) {
1738     auto module = absl::make_unique<VerifiedHloModule>(
1739         ::testing::UnitTest::GetInstance()->current_test_info()->name(),
1740         HloModuleConfig(),
1741         /*verifier_layout_sensitive=*/false,
1742         /*allow_mixed_precision_in_hlo_verifier=*/true,
1743         ShapeUtil::ByteSizeOfElements);
1744     TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
1745     return std::move(module);
1746   }
1747 };
1748 
1749 TEST_F(HloParserTest, Empty) {
1750   const string original = "";
1751   auto result = ParseAndReturnUnverifiedModule(original);
1752   EXPECT_NE(Status::OK(), result.status());
1753 }
1754 
1755 TEST_F(HloParserTest, Garbage) {
1756   const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
1757   auto result = ParseAndReturnUnverifiedModule(original);
1758   EXPECT_NE(Status::OK(), result.status());
1759 }
1760 
1761 TEST_F(HloParserTest, WrongOpcode) {
1762   const string original = R"(HloModule wrong_opcode:
1763 
1764 ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
1765   %x = f32[]{} parameter(0)
1766   %y = f32[]{} parameter(1)
1767   %le = pred[]{} le(f32[]{} %x, f32[]{} %y)
1768 }
1769 
1770 )";
1771   auto result = ParseAndReturnUnverifiedModule(original);
1772   EXPECT_NE(Status::OK(), result.status());
1773 }
1774 
1775 TEST_F(HloParserTest, MetadataWithCholesky) {
1776   const string original = R"(HloModule metadata_with_cholesky
1777 ENTRY %blabla (a: f32[1,291,291]) -> f32[1,291,291] {
1778   %a = f32[1,291,291] parameter(0)
1779   %out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true, metadata={op_type="Cholesky" op_name="Cholesky"}
1780 }
1781 )";
1782   auto result = ParseAndReturnVerifiedModule(original);
1783   EXPECT_EQ(Status::OK(), result.status());
1784   EXPECT_EQ("Cholesky", result.ValueOrDie()
1785                             ->entry_computation()
1786                             ->root_instruction()
1787                             ->metadata()
1788                             .op_name());
1789   EXPECT_EQ("Cholesky", result.ValueOrDie()
1790                             ->entry_computation()
1791                             ->root_instruction()
1792                             ->metadata()
1793                             .op_type());
1794 }
1795 
1796 TEST_F(HloParserTest, WrongShape) {
1797   const string original = R"(HloModule wrong_opcode:
1798 
1799 ENTRY %blabla (x: g32[]) -> g32[] {
1800   %x = g32[]{} parameter(0)
1801 }
1802 
1803 )";
1804   auto result = ParseAndReturnUnverifiedModule(original);
1805   EXPECT_NE(Status::OK(), result.status());
1806 }
1807 
1808 TEST_F(HloParserTest, WrongOperandsSize) {
1809   const string original = R"(HloModule wrong_opcode:
1810 
1811 ENTRY %blabla (x: f32[]) -> pred[] {
1812   %x = f32[]{} parameter(0)
1813   %eq = pred[]{} compare(f32[]{} %x), direction=EQ
1814 }
1815 
1816 )";
1817   auto result = ParseAndReturnUnverifiedModule(original);
1818   EXPECT_NE(Status::OK(), result.status());
1819 }
1820 
1821 TEST_F(HloParserTest, OperandNotFound) {
1822   const string original = R"(HloModule operand_not_found:
1823 ENTRY %blabla (x: f32[]) -> pred[] {
1824   %x = f32[]{} parameter(0)
1825   %eq = pred[]{} compare(f32[]{} %x, f32[]{} %y), direction=EQ
1826 }
1827 )";
1828   auto result = ParseAndReturnUnverifiedModule(original);
1829   EXPECT_NE(Status::OK(), result.status());
1830 }
1831 
1832 TEST_F(HloParserTest, MoreConstants) {
1833   const string original = R"(HloModule SelectScalarS32True_module
1834 
1835 ENTRY %SelectScalarS32True.v4 () -> s32[] {
1836   %constant.2 = pred[] constant(true)
1837   %constant.1 = s32[] constant(-42), sharding={devices=[2,2]1,2,3,4}
1838   %constant = s32[] constant(42)
1839   %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
1840 }
1841 
1842 )";
1843   auto result = ParseAndReturnVerifiedModule(original);
1844   TF_EXPECT_OK(result.status());
1845   // Constant instructions have no name. The string will be parsed successfully
1846   // but the constant names will not be exactly the same.
1847 }
1848 
1849 TEST_F(HloParserTest, ConfigurationField) {
1850   const string original = R"(HloModule AModule
1851 ENTRY %configuration_test() -> s32[] {
1852   %constant = s32[] constant(42), backend_config="foo bar"
1853 })";
1854   auto result = ParseAndReturnVerifiedModule(original);
1855   TF_ASSERT_OK(result.status());
1856   EXPECT_EQ("foo bar", result.ValueOrDie()
1857                            ->entry_computation()
1858                            ->root_instruction()
1859                            ->raw_backend_config_string());
1860 }
1861 
1862 TEST_F(HloParserTest, LiteralDimensionsMismatch_1) {
1863   const string original = R"(HloModule some_2_module
1864 
1865 ENTRY %some_2 () -> f32[2] {
1866   ROOT %constant = f32[2]{0} constant({1,{2}})
1867 }
1868 
1869 )";
1870   auto result = ParseAndReturnUnverifiedModule(original);
1871   EXPECT_NE(Status::OK(), result.status());
1872   ExpectHasSubstr(result.status().error_message(),
1873                   "expects nested array in rank 1, but sees larger");
1874 }
1875 
1876 TEST_F(HloParserTest, LiteralDimensionsMismatch_2) {
1877   const string original = R"(HloModule some_2x3_module
1878 
1879 ENTRY %some_2x3 () -> f32[2,3] {
1880   ROOT %constant = f32[2,3]{1,0} constant({1, 2, 3, 4, 5, 6})
1881 }
1882 
1883 )";
1884   auto result = ParseAndReturnUnverifiedModule(original);
1885   EXPECT_NE(Status::OK(), result.status());
1886   ExpectHasSubstr(result.status().error_message(),
1887                   "expects nested array in rank 2, but sees 1");
1888 }
1889 
1890 TEST_F(HloParserTest, LiteralDimensionsMismatch_3) {
1891   const string original = R"(HloModule some_2x3x2_module
1892 
1893 ENTRY %some_2x3x2 () -> f32[2,3,2] {
1894   ROOT %constant = f32[2,3,2]{2,1,0} constant({{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}})
1895 }
1896 
1897 )";
1898   auto result = ParseAndReturnUnverifiedModule(original);
1899   EXPECT_NE(Status::OK(), result.status());
1900   ExpectHasSubstr(result.status().error_message(),
1901                   "expects 3 elements in the [0]th element");
1902 }
1903 
1904 TEST_F(HloParserTest, ConstantF16Overflow) {
1905   const string original =
1906       R"(HloModule ConstantF16Overflow_module
1907 
1908 ENTRY %ConstantF16Overflow.v4 () -> f16[] {
1909   ROOT %constant = f16[] constant(-65520)
1910 }
1911 
1912 )";
1913   auto result = ParseAndReturnUnverifiedModule(original);
1914   EXPECT_NE(Status::OK(), result.status());
1915   ExpectHasSubstr(result.status().error_message(),
1916                   "is out of range for literal's primitive type F16");
1917 }
1918 
1919 TEST_F(HloParserTest, ConstantBf16NoOverflow) {
1920   // 65505 is in range for bf16.
1921   const string original = R"(
1922   HloModule test_module
1923   ENTRY test {
1924     ROOT c = bf16[] constant(-65505)
1925   })";
1926   EXPECT_EQ(Status::OK(), ParseAndReturnVerifiedModule(original).status());
1927 }
1928 
1929 TEST_F(HloParserTest, ConstantBf16Overflow) {
1930   // 1e100 is out of range for bf16.
1931   const string original = R"(
1932   HloModule test_module
1933   ENTRY test {
1934     ROOT c = bf16[] constant(1e100)
1935   })";
1936   ExpectHasSubstr(
1937       ParseAndReturnUnverifiedModule(original).status().error_message(),
1938       "out of range");
1939 }
1940 
1941 TEST_F(HloParserTest, ConstantUnsignedUnderflow) {
1942   const string original = R"(
1943       HloModule ConstantUnsignedUnderflow_module
1944       ENTRY %ConstantUnsignedUnderflow () -> u64[] {
1945         ROOT %constant = u64[] constant(-1)
1946       })";
1947   auto result = ParseAndReturnUnverifiedModule(original);
1948   EXPECT_NE(Status::OK(), result.status());
1949   ExpectHasSubstr(result.status().error_message(),
1950                   "is out of range for literal's primitive type U64");
1951 }
1952 
1953 TEST_F(HloParserTest, ConstantUnsignedOverflow) {
1954   const string original = R"(
1955       HloModule ConstantUnsignedOverflow_module
1956       ENTRY %ConstantUnsignedOverflow () -> u32[] {
1957         ROOT %constant = u32[] constant(4294967296)
1958       })";
1959   auto result = ParseAndReturnUnverifiedModule(original);
1960   EXPECT_NE(Status::OK(), result.status());
1961   ExpectHasSubstr(result.status().error_message(),
1962                   "is out of range for literal's primitive type U32");
1963 }
1964 
1965 TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) {
1966   const string original = R"(
1967       HloModule ConstantUnsignedOverflow_module
1968       ENTRY %ConstantUnsignedOverflow () -> u64[] {
1969         ROOT %constant = u64[] constant(9223372036854775808)
1970       })";
1971   auto result = ParseAndReturnUnverifiedModule(original);
1972   EXPECT_NE(Status::OK(), result.status());
1973 }
1974 
1975 TEST_F(HloParserTest, ConstantC64Overflow) {
1976   const string original = R"(
1977       HloModule test_module
1978       ENTRY test () -> c64[] {
1979         ROOT c = c64[] constant((1e100, 0))
1980       })";
1981   auto result = ParseAndReturnUnverifiedModule(original);
1982   EXPECT_NE(Status::OK(), result.status());
1983 }
1984 
1985 TEST_F(HloParserTest, ConstantC64Underflow) {
1986   const string original = R"(
1987       HloModule test_module
1988       ENTRY test () -> c64[] {
1989         ROOT c = c64[] constant((0, -1e100))
1990       })";
1991   auto result = ParseAndReturnUnverifiedModule(original);
1992   EXPECT_NE(Status::OK(), result.status());
1993 }
1994 
1995 TEST_F(HloParserTest, ConstantF64Overflow) {
1996   const string original = R"(
1997       HloModule test_module
1998       ENTRY test {
1999         ROOT c = f64[] constant(1.8e308)
2000       })";
2001   auto result = ParseAndReturnUnverifiedModule(original);
2002   EXPECT_NE(Status::OK(), result.status());
2003 }
2004 
2005 TEST_F(HloParserTest, ConstantF64Underflow) {
2006   const string original = R"(
2007       HloModule test_module
2008       ENTRY test {
2009         ROOT c = f64[] constant(-1.8e308)
2010       })";
2011   auto result = ParseAndReturnUnverifiedModule(original);
2012   EXPECT_NE(Status::OK(), result.status());
2013 }
2014 
2015 TEST_F(HloParserTest, ConstantWithExp) {
2016   const string original = R"(HloModule ConstantWithExp_module
2017 
2018 ENTRY %ConstantWithExp.v4 () -> f32[] {
2019   %constant.1 = f32[] constant(3e+2)
2020 }
2021 
2022 )";
2023   auto result = ParseAndReturnVerifiedModule(original);
2024   TF_EXPECT_OK(result.status());
2025   // The string will be parsed successfully but the output strings are not
2026   // exactly the same, because "3e2" is parsed into value 300 and will be
2027   // printed as "300".
2028 }
2029 
2030 TEST_F(HloParserTest, ShortConstant) {
2031   const string original = R"(HloModule ShortConstant_module
2032 
2033 ENTRY %ShortConstant.v4 () -> f32[67,89] {
2034   ROOT %constant.1 = f32[67,89]{1,0} constant({...})
2035 }
2036 
2037 )";
2038   auto result = ParseAndReturnVerifiedModule(original);
2039   TF_EXPECT_OK(result.status());
2040   EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
2041 }
2042 
2043 TEST_F(HloParserTest, AttributesAnyOrder) {
2044   const string original = R"(HloModule any_order_module
2045 
2046 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,4,1] {
2047   %input = f32[1,2,1]{2,1,0} parameter(0)
2048   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
2049   %filter = f32[1,1,1]{2,1,0} parameter(1)
2050   ROOT %convolution = f32[1,4,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=1}
2051 }
2052 
2053 )";
2054   TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
2055 }
2056 
2057 TEST_F(HloParserTest, InvalidDimLabels) {
2058   string prefix = R"(HloModule invalid_dim_labels_module
2059 
2060 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
2061   %input = f32[1,2,1]{2,1,0} parameter(0)
2062   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
2063   %filter = f32[1,1,1]{2,1,0} parameter(1)
2064   ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )";
2065   string suffix = R"(
2066 }
2067 
2068 )";
2069 
2070   ExpectHasSubstr(ParseAndReturnUnverifiedModule(
2071                       absl::StrCat(prefix, ",dim_labels=00_01_10", suffix))
2072                       .status()
2073                       .error_message(),
2074                   "expects dim labels pattern");
2075 
2076   ExpectHasSubstr(ParseAndReturnUnverifiedModule(
2077                       absl::StrCat(prefix, ",dim_labels=010_1100->010", suffix))
2078                       .status()
2079                       .error_message(),
2080                   "must have the same rank");
2081 }
2082 
2083 TEST_F(HloParserTest, UnexpectedAttribute) {
2084   const string original = R"(HloModule unexpected_attr_module
2085 
2086 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
2087   %token0 = token[] after-all()
2088   %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
2089   %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
2090   ROOT %constant = f32[] constant(2.1)
2091   %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, calls=%recv
2092   %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
2093 }
2094 
2095 )";
2096   ExpectHasSubstr(
2097       ParseAndReturnUnverifiedModule(original).status().error_message(),
2098       "unexpected attribute \"calls\"");
2099 }
2100 
2101 TEST_F(HloParserTest, MissingAttribute) {
2102   const string original = R"(HloModule missing_attr_module
2103 
2104 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
2105   %token0 = token[] after-all()
2106   %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
2107   %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
2108   ROOT %constant = f32[] constant(-2.1)
2109   %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0)
2110   %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
2111 }
2112 
2113 )";
2114   ExpectHasSubstr(
2115       ParseAndReturnUnverifiedModule(original).status().error_message(),
2116       "attribute channel_id is expected but not seen");
2117 }
2118 
2119 TEST_F(HloParserTest, PredecessorUndefined) {
2120   const string original = R"(HloModule pre_not_found_module
2121 
2122 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
2123   %token0 = token[] after-all()
2124   %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
2125   %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
2126   ROOT %constant = f32[] constant(2.1)
2127   %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, control-predecessors={%done}
2128   %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
2129 }
2130 
2131 )";
2132   ExpectHasSubstr(
2133       ParseAndReturnUnverifiedModule(original).status().error_message(),
2134       "'done' is not defined");
2135 }
2136 
2137 TEST_F(HloParserTest, SliceAllowOmitStride1) {
2138   const string original = R"(HloModule slice_module
2139 
2140 ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
2141   %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
2142   ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3], [0:3], [0:4:2], [0:4]}
2143 }
2144 
2145 )";
2146   TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
2147 }
2148 
2149 TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) {
2150   const string original = R"(HloModule window_pad_module
2151 
2152 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
2153   %input = f32[1,2,1]{2,1,0} parameter(0)
2154   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
2155   %filter = f32[1,1,1]{2,1,0} parameter(1)
2156   ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), dim_labels=b0f_0io->b0f, window={pad=1_1_0 size=1}
2157 }
2158 
2159 )";
2160   ExpectHasSubstr(
2161       ParseAndReturnUnverifiedModule(original).status().error_message(),
2162       "expects padding_low and padding_high separated by '_'");
2163 }
2164 
2165 TEST_F(HloParserTest, CommaBetweenSubAttributes) {
2166   const string original = R"(HloModule test_comma_module
2167 
2168 ENTRY %test_comma.v4 () -> f32[] {
2169   ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"}
2170 }
2171 
2172 )";
2173   TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
2174 }
2175 
2176 TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) {
2177   const string original = R"(HloModule custom_call:
2178 
2179 ENTRY %CustomCall () -> f32[1] {
2180   %constant = f32[1]{0} constant({12345})
2181   ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar"
2182 })";
2183   ExpectHasSubstr(
2184       ParseAndReturnUnverifiedModule(original).status().error_message(),
2185       "Shape of computation CustomCall, f32[1], is not compatible "
2186       "with that of its root instruction foo, f32[1,2,3]");
2187 }
2188 
2189 TEST_F(HloParserTest, EntryComputationWithLayout) {
2190   const string original = R"(HloModule layout:
2191 add_F32.v3 {
2192   lhs = f32[] parameter(0)
2193   rhs = f32[] parameter(1)
2194   ROOT add = f32[] add(lhs, rhs)
2195 }
2196 
2197 ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
2198   input = f32[8,16,256]{0,1,2} parameter(0)
2199   constant = f32[] constant(0)
2200   ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
2201 })";
2202 
2203   auto module = ParseAndReturnVerifiedModule(original);
2204   TF_ASSERT_OK(module.status());
2205   auto program_layout = module.ValueOrDie()->entry_computation_layout();
2206   ASSERT_EQ(program_layout.parameter_count(), 1);
2207   auto param_layout = program_layout.parameter_layout(0).layout();
2208   auto result_layout = program_layout.result_layout().layout();
2209   EXPECT_TRUE(
2210       LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2}), param_layout))
2211       << "actual layout of parameter(0) is "
2212       << LayoutUtil::HumanString(param_layout);
2213   EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}), result_layout))
2214       << "actual layout of result is "
2215       << LayoutUtil::HumanString(result_layout);
2216 }
2217 
2218 TEST_F(HloParserTest, NoEntry) {
2219   const string original = R"(HloModule no_entry:
2220 c1 {
2221   const1 = f32[1]{0} constant({12345})
2222 }
2223 c2 {
2224   const2 = f32[1]{0} constant({67890})
2225 })";
2226   auto module = ParseAndReturnVerifiedModule(original);
2227   TF_ASSERT_OK(module.status());
2228   EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2");
2229 }
2230 
2231 TEST_F(HloParserTest, NoRoot) {
2232   const string original = R"(HloModule no_root:
2233 ENTRY consts {
2234   first = f32[1]{0} constant({12345})
2235   last = f32[1]{0} constant({67890})
2236 })";
2237   auto module = ParseAndReturnVerifiedModule(original);
2238   TF_ASSERT_OK(module.status());
2239   EXPECT_EQ(
2240       module.ValueOrDie()->entry_computation()->root_instruction()->name(),
2241       "last");
2242 }
2243 
2244 TEST_F(HloParserTest, Comments) {
2245   const string original = R"(/* module description. */
2246 HloModule comments:
2247 
2248 ENTRY /*comment*/ c1 {
2249   /* blah */
2250   ROOT const1 = /*foo*/f32[1]{0} constant({12345 /*bar*/})
2251   /* comment */
2252 }
2253 
2254 /* something else */
2255 
2256 )";
2257   auto module = ParseAndReturnVerifiedModule(original);
2258   TF_ASSERT_OK(module.status());
2259 }
2260 
2261 TEST_F(HloParserTest, MultilineComments) {
2262   const string original = R"(HloModule multiline_comment:
2263 ENTRY c1 {
2264   /*
2265      ROOT foo = f32[1]{0} constant({12345})
2266   */
2267   ROOT const1 = f32[1]{0} constant({12345})
2268 /*
2269 a
2270 b
2271 c
2272 d
2273 
2274 */
2275 })";
2276   auto module = ParseAndReturnVerifiedModule(original);
2277   TF_ASSERT_OK(module.status());
2278 }
2279 
2280 TEST_F(HloParserTest, UnterminatedComment) {
2281   const string original = R"(HloModule unterminated_comment:
2282 ENTRY c1 {
2283 /* unterminated
2284   ROOT const1 = f32[1]{0} constant({12345})
2285 })";
2286   // Verify that the error message points to the beginning of the unterminated
2287   // comment.
2288   ExpectHasSubstr(
2289       ParseAndReturnUnverifiedModule(original).status().error_message(),
2290       "/* unterminated\n^");
2291 }
2292 
2293 TEST_F(HloParserTest, SlashSlashComments) {
2294   const string original = R"(HloModule slash_slash_comment:
2295 // Garbage
2296 ENTRY c1 {
2297   // Foo bar
2298   ROOT const1 = f32[1]{0} constant({12345}) // Something else
2299 })";
2300   auto module = ParseAndReturnVerifiedModule(original);
2301   TF_ASSERT_OK(module.status());
2302 }
2303 
2304 TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) {
2305   const string original =
2306       "HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo "
2307       "bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}";
2308   auto module = ParseAndReturnVerifiedModule(original);
2309   TF_ASSERT_OK(module.status());
2310 }
2311 
2312 TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) {
2313   const string original =
2314       "HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo "
2315       "bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}";
2316   auto module = ParseAndReturnVerifiedModule(original);
2317   TF_ASSERT_OK(module.status());
2318 }
2319 
2320 TEST_F(HloParserTest, MultipleEntries) {
2321   const string original = R"(HloModule multiple_entries:
2322 ENTRY c1 {
2323   const1 = f32[1]{0} constant({12345})
2324 }
2325 ENTRY c2 {
2326   const2 = f32[1]{0} constant({67890})
2327 })";
2328   ExpectHasSubstr(
2329       ParseAndReturnUnverifiedModule(original).status().error_message(),
2330       "expects only one ENTRY");
2331 }
2332 
2333 TEST_F(HloParserTest, MultipleRoots) {
2334   const string original = R"(HloModule multiple_roots:
2335 ENTRY consts {
2336   ROOT const1 = f32[1]{0} constant({12345})
2337   ROOT const2 = f32[1]{0} constant({12345})
2338 })";
2339   ExpectHasSubstr(
2340       ParseAndReturnUnverifiedModule(original).status().error_message(),
2341       "one computation should have only one ROOT");
2342 }
2343 
2344 TEST_F(HloParserTest, ComputationExists) {
2345   const string original = R"(HloModule comp_exists
2346 comp {
2347   const1 = f32[1]{0} constant({12345})
2348 }
2349 comp {
2350   const2 = f32[1]{0} constant({67890})
2351 })";
2352   ExpectHasSubstr(
2353       ParseAndReturnUnverifiedModule(original).status().error_message(),
2354       R"(was parsing 2:1: error: computation previously defined here
2355 comp {
2356 ^)");
2357 }
2358 
2359 TEST_F(HloParserTest, CrossComputationLookup) {
2360   const string original = R"(HloModule cross_computation_lookup:
2361 tcalla (a: (s32[], s32[])) -> (s32[], s32[]) {
2362   ROOT aparam = (s32[], s32[]) parameter(0)
2363 }
2364 
2365 tcallb (b: (s32[], s32[])) -> s32[] {
2366   rparam = (s32[], s32[]) parameter(0)
2367   ROOT gte0 = s32[] get-tuple-element(aparam), index=0
2368 }
2369 
2370 ENTRY entry {
2371   param = (s32[], s32[]) parameter(0)
2372   call0 = (s32[], s32[]) call(param), to_apply=tcalla
2373   ROOT call1 = s32[] call(param), to_apply=tcallb
2374 })";
2375   ExpectHasSubstr(
2376       ParseAndReturnUnverifiedModule(original).status().error_message(),
2377       "was parsing 8:39: error: instruction does not exist: aparam");
2378 }
2379 
2380 TEST_F(HloParserTest, SameNameDiffComputations) {
2381   const string original = R"(HloModule same_names:
2382 add {
2383   p0 = f32[] parameter(0)
2384   p1 = f32[] parameter(1)
2385   ROOT result = f32[] add(p0, p1)
2386 }
2387 
2388 ENTRY ReduceR3ToR2 {
2389   p0 = f32[8,16,256]{2,1,0} parameter(0)
2390   p1 = f32[] constant(0)
2391   ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add
2392 }
2393 )";
2394   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(original));
2395   ASSERT_NE(module->entry_computation(), nullptr);
2396   EXPECT_THAT(module->entry_computation()->root_instruction(),
2397               GmockMatch(m::Reduce()));
2398 }
2399 
2400 TEST_F(HloParserTest, ParseSharding) {
2401   const string original = "{maximal device=42}";
2402   TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
2403   EXPECT_EQ(sharding.ToString(), original);
2404 }
2405 
2406 TEST_F(HloParserTest, ParseFrontendAttributes) {
2407   const string original = "{attr_a=test_a,attr_b=b}";
2408   TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes,
2409                           ParseFrontendAttributes(original));
2410   EXPECT_EQ(FrontendAttributesToString(frontend_attributes), original);
2411 }
2412 
2413 TEST_F(HloParserTest, ParseWindow) {
2414   Window original = window_util::MakeWindow({1, 2, 3});
2415   TF_ASSERT_OK_AND_ASSIGN(Window parsed,
2416                           ParseWindow(window_util::ToString(original)))
2417   EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed));
2418 }
2419 
2420 TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
2421   const string original = "b0f_0io->b0f";
2422   TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums,
2423                           ParseConvolutionDimensionNumbers(original));
2424   EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
2425 }
2426 
2427 TEST_F(HloParserTest, ParseReplicaGroups) {
2428   const string original = "{{0,1},{2,3}}";
2429   TF_ASSERT_OK_AND_ASSIGN(std::vector<ReplicaGroup> replica_groups,
2430                           ParseReplicaGroupsOnly(original));
2431   EXPECT_EQ(original, ReplicaGroupsToString(replica_groups));
2432 }
2433 
2434 TEST_F(HloParserTest, ParsePaddingConfigNoInteriorPadding) {
2435   const string original = "0_1x2_3";
2436   TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
2437   EXPECT_EQ(original, PaddingConfigToString(dnums));
2438 }
2439 
2440 TEST_F(HloParserTest, ParsePaddingConfigInteriorPadding) {
2441   const string original = "0_1_0x2_3_4";
2442   TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
2443   EXPECT_EQ(original, PaddingConfigToString(dnums));
2444 }
2445 
2446 TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) {
2447   TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig("0_1x2_3_4"));
2448   // The extra "_0" gets added to the canonical string because the other dim has
2449   // interior padding.
2450   EXPECT_EQ("0_1_0x2_3_4", PaddingConfigToString(dnums));
2451 }
2452 
2453 TEST_F(HloParserTest, NontupleInfeed) {
2454   const string original = R"(HloModule nontuple_infeed:
2455 ENTRY nontuple_infeed {
2456   token0 = token[] after-all()
2457   ROOT infeed = pred[] infeed(token0)
2458 })";
2459   ExpectHasSubstr(
2460       ParseAndReturnUnverifiedModule(original).status().error_message(),
2461       "infeed must have a non-empty tuple shape");
2462 }
2463 
2464 TEST(HloParserSingleOpTest, SingleOp) {
2465   const string text =
2466       "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, "
2467       "f32[2,4]{1,0} %x)";
2468   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
2469   const HloComputation* computation = module->entry_computation();
2470   ASSERT_NE(computation, nullptr);
2471   EXPECT_THAT(computation->root_instruction(),
2472               GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
2473 }
2474 
2475 TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) {
2476   const string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)";
2477   StatusOr<std::unique_ptr<HloModule>> module =
2478       ParseAndReturnUnverifiedModule(text);
2479   ASSERT_TRUE(!module.status().ok());
2480   LOG(INFO) << "Status: " << module.status();
2481   EXPECT_THAT(module.status().ToString(),
2482               ::testing::HasSubstr("expects '=' in instruction"));
2483 }
2484 
2485 TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) {
2486   const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)";
2487   StatusOr<std::unique_ptr<HloModule>> module =
2488       ParseAndReturnUnverifiedModule(text);
2489   ASSERT_TRUE(!module.status().ok());
2490   LOG(INFO) << "Status: " << module.status();
2491   EXPECT_THAT(module.status().ToString(),
2492               ::testing::HasSubstr("Operand had no shape in HLO text"));
2493 }
2494 
2495 TEST(HloParserSingleOpTest, SingleOpNoNames) {
2496   const string text =
2497       "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
2498   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
2499   const HloComputation* computation = module->entry_computation();
2500   ASSERT_NE(computation, nullptr);
2501   EXPECT_THAT(computation->root_instruction(),
2502               GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
2503 }
2504 
2505 TEST(HloParserSingleOpTest, CanonicalOp) {
2506   const string text = "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
2507   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
2508   const HloComputation* computation = module->entry_computation();
2509   ASSERT_NE(computation, nullptr);
2510   EXPECT_THAT(computation->root_instruction(),
2511               GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
2512   EXPECT_EQ(
2513       computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
2514       text);
2515 }
2516 
2517 TEST(HloParserSingleOpTest, CanonicalOpWithNested) {
2518   const string text =
2519       R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
2520 {
2521   tmp_0 = f32[5,10]{1,0} parameter(0)
2522   tmp_1 = f32[20,10]{1,0} parameter(1)
2523   ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
2524   {
2525     tmp_0 = f32[5,10]{1,0} parameter(0)
2526     tmp_1 = f32[20,10]{1,0} parameter(1)
2527     tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
2528     ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
2529   }
2530 }, body=
2531 {
2532   tmp_0 = f32[5,10]{1,0} parameter(0)
2533   tmp_1 = f32[20,10]{1,0} parameter(1)
2534   ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
2535   {
2536     tmp_0 = f32[5,10]{1,0} parameter(0)
2537     tmp_1 = f32[20,10]{1,0} parameter(1)
2538     tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
2539     ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
2540   }
2541 })";
2542 
2543   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
2544   const HloComputation* computation = module->entry_computation();
2545   ASSERT_NE(computation, nullptr);
2546   EXPECT_EQ(
2547       computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
2548       text);
2549 }
2550 
2551 TEST(HloParserSingleOpTest, CanonicalOpIndexedConditionalInlinedBranches) {
2552   const string text =
2553       R"(f32[5,10]{1,0} conditional(s32[], f32[5,10]{1,0}, f32[5,10]{1,0}, f32[5,10]{1,0}), branch_computations={
2554 {
2555   tmp_0 = f32[5,10]{1,0} parameter(0)
2556   ROOT tmp_1 = f32[5,10]{1,0} ceil(f32[5,10]{1,0} tmp_0)
2557 },
2558 {
2559   tmp_0 = f32[5,10]{1,0} parameter(0)
2560   ROOT tmp_1 = f32[5,10]{1,0} floor(f32[5,10]{1,0} tmp_0)
2561 },
2562 {
2563   tmp_0 = f32[5,10]{1,0} parameter(0)
2564   ROOT tmp_1 = f32[5,10]{1,0} copy(f32[5,10]{1,0} tmp_0)
2565 }
2566 })";
2567 
2568   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
2569   const HloComputation* computation = module->entry_computation();
2570   ASSERT_NE(computation, nullptr);
2571   EXPECT_EQ(
2572       computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
2573       text);
2574 }
2575 
2576 TEST(HloParserSingleOpTest, SingleOpWithNested) {
2577   const string text =
2578       R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls=
2579 {
2580   %param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
2581   %param_1 = f32[2]{0} parameter(1)
2582   %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %param_1), dimensions={1}
2583   ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
2584 })";
2585 
2586   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
2587   const HloComputation* computation = module->entry_computation();
2588   ASSERT_NE(computation, nullptr);
2589   EXPECT_THAT(computation->root_instruction(),
2590               GmockMatch(m::Op()
2591                              .WithOpcode(HloOpcode::kFusion)
2592                              .WithNumOperands(2)
2593                              .WithOperand(0, m::Parameter(0))
2594                              .WithOperand(1, m::Parameter(1))));
2595 }
2596 
2597 TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) {
2598   const string text =
2599       R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
2600 {
2601   result = f32[] add(f32[] x, f32[] y)
2602 })";
2603   auto status = ParseAndReturnUnverifiedModule(text).status();
2604   ASSERT_FALSE(status.ok());
2605   EXPECT_THAT(status.error_message(),
2606               ::testing::HasSubstr("does not exist: x"));
2607 }
2608 
2609 TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) {
2610   const string text =
2611       R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
2612 {
2613   f32[] add(f32[] x, f32[] y)
2614 })";
2615   auto status = ParseAndReturnUnverifiedModule(text).status();
2616   ASSERT_FALSE(status.ok());
2617   EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
2618 }
2619 
2620 TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) {
2621   const string text =
2622       R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
2623 {
2624   result = f32[] add(f32[], f32[])
2625 })";
2626   auto status = ParseAndReturnUnverifiedModule(text).status();
2627   ASSERT_FALSE(status.ok());
2628   EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
2629 }
2630 
2631 TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
2632   const string text =
2633       R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)";
2634   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
2635   const HloComputation* computation = module->entry_computation();
2636   ASSERT_NE(computation, nullptr);
2637   EXPECT_THAT(computation->root_instruction(),
2638               GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
2639   auto* convolution =
2640       Cast<HloConvolutionInstruction>(computation->root_instruction());
2641   EXPECT_EQ(convolution->feature_group_count(), 1);
2642 }
2643 
2644 TEST(HloParserSingleOpTest, MultipleOpsProducesError) {
2645   const string text = R"(
2646     param = f32[2,5,1,3] parameter(0)
2647     transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3}
2648   )";
2649   auto status = ParseAndReturnUnverifiedModule(text).status();
2650   ASSERT_FALSE(status.ok());
2651   EXPECT_THAT(status.error_message(), ::testing::HasSubstr("Expected eof"));
2652 }
2653 
2654 TEST_F(HloParserTest, IsScheduledIsFalse) {
2655   const string text = R"(
2656 HloModule axpy_module, is_scheduled=false
2657 
2658 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2659   %alpha = f32[] parameter(0)
2660   %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2661   %x = f32[2,4]{1,0} parameter(1)
2662   %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2663   %y = f32[2,4]{1,0} parameter(2)
2664   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2665 }
2666 )";
2667   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
2668   ASSERT_FALSE(module->has_schedule());
2669 }
2670 
2671 TEST_F(HloParserTest, IsScheduledNotPresent) {
2672   const string text = R"(
2673 HloModule axpy_module
2674 
2675 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2676   %alpha = f32[] parameter(0)
2677   %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2678   %x = f32[2,4]{1,0} parameter(1)
2679   %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2680   %y = f32[2,4]{1,0} parameter(2)
2681   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2682 }
2683 )";
2684   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
2685   ASSERT_FALSE(module->has_schedule());
2686 }
2687 
2688 TEST_F(HloParserTest, IsScheduledIsTrue) {
2689   const string text = R"(
2690 HloModule axpy_module, is_scheduled=true
2691 
2692 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2693   %alpha = f32[] parameter(0)
2694   %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2695   %x = f32[2,4]{1,0} parameter(1)
2696   %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2697   %y = f32[2,4]{1,0} parameter(2)
2698   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2699 }
2700 )";
2701   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
2702   ASSERT_TRUE(module->has_schedule());
2703   TF_ASSERT_OK(module->schedule().Verify());
2704   EXPECT_EQ(module->schedule().sequences().size(), 1);
2705   ASSERT_TRUE(
2706       module->schedule().is_computation_scheduled(module->entry_computation()));
2707   EXPECT_THAT(
2708       module->schedule().sequence(module->entry_computation()).instructions(),
2709       ::testing::ElementsAre(
2710           GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
2711           GmockMatch(m::Parameter()), GmockMatch(m::Multiply()),
2712           GmockMatch(m::Parameter()), GmockMatch(m::Add())));
2713 }
2714 
2715 TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) {
2716   // As above but in with a different schedule order.
2717   const string text = R"(
2718 HloModule axpy_module, is_scheduled=true
2719 
2720 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2721   %alpha = f32[] parameter(0)
2722   %x = f32[2,4]{1,0} parameter(1)
2723   %y = f32[2,4]{1,0} parameter(2)
2724   %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2725   %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2726   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2727 }
2728 )";
2729   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
2730   ASSERT_TRUE(module->has_schedule());
2731   TF_ASSERT_OK(module->schedule().Verify());
2732   EXPECT_EQ(module->schedule().sequences().size(), 1);
2733   ASSERT_TRUE(
2734       module->schedule().is_computation_scheduled(module->entry_computation()));
2735   EXPECT_THAT(
2736       module->schedule().sequence(module->entry_computation()).instructions(),
2737       ::testing::ElementsAre(
2738           GmockMatch(m::Parameter()), GmockMatch(m::Parameter()),
2739           GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
2740           GmockMatch(m::Multiply()), GmockMatch(m::Add())));
2741 }
2742 
2743 TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) {
2744   const string original = R"(HloModule CustomCallWrongNumberofOperandConstraints
2745 
2746 ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
2747   %p0 = f32[42,2,3]{0,1,2} parameter(0)
2748   %p1 = f32[123,4]{0,1} parameter(1)
2749   ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}}
2750 }
2751 
2752 )";
2753   ExpectHasSubstr(
2754       ParseAndReturnUnverifiedModule(original).status().error_message(),
2755       "Expected 2 operand layout constraints, 1 given");
2756 }
2757 
2758 TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) {
2759   const string original = R"(HloModule CustomCallIncompatibleOperandConstraints
2760 
2761 ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
2762   %p0 = f32[42,2,3]{0,1,2} parameter(0)
2763   %p1 = f32[123,4]{0,1} parameter(1)
2764   ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}}
2765 }
2766 
2767 )";
2768   ExpectHasSubstr(
2769       ParseAndReturnUnverifiedModule(original).status().error_message(),
2770       "operand 1 is not compatible with operand shape");
2771 }
2772 
2773 TEST_F(HloParserTest, AllowShapeWhitespace) {
2774   const string text = R"(
2775 HloModule module
2776 
2777 ENTRY entry {
2778   ROOT root = f32[ 1, 2,3, 4, 5]{0, 1, 2,3, 4 } parameter(0)
2779 }
2780 )";
2781   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
2782 }
2783 
2784 TEST_F(HloParserTest, ShapeMismatchInOperand) {
2785   const string text = R"(
2786 HloModule foobar
2787 
2788 ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] {
2789   %p = f32[2,2] parameter(0)
2790   %constant.1 = f32[2,2] constant({{1, 2}, {3, 4}})
2791   ROOT %add.1 = f32[2,2] add(f32[2,2] %p, f32[2,5] %constant.1)
2792 }
2793 )";
2794 
2795   ExpectHasSubstr(ParseAndReturnUnverifiedModule(text).status().error_message(),
2796                   "The declared operand shape f32[2,5]{1,0} is not compatible"
2797                   " with the shape of the operand instruction f32[2,2]{1,0}.");
2798 }
2799 
2800 TEST_F(HloParserTest, ParseShapeStringR2F32) {
2801   string shape_string = "f32[123,456]";
2802   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2803   Shape expected = ShapeUtil::MakeShape(F32, {123, 456});
2804   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2805       << "expected: " << ShapeUtil::HumanString(expected)
2806       << "actual:   " << ShapeUtil::HumanString(actual);
2807 }
2808 
2809 TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) {
2810   string shape_string = "(f32[1572864],s8[5120,1024])";
2811   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2812   Shape expected =
2813       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}),
2814                                  ShapeUtil::MakeShape(S8, {5120, 1024})});
2815   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2816       << "expected: " << ShapeUtil::HumanString(expected)
2817       << "actual:   " << ShapeUtil::HumanString(actual);
2818 }
2819 
2820 TEST_F(HloParserTest, ParseShapeStringNestedTuple) {
2821   string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])";
2822   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2823   Shape expected = ShapeUtil::MakeTupleShape({
2824       ShapeUtil::MakeShape(F32, {1}),
2825       ShapeUtil::MakeTupleShape(
2826           {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}),
2827       ShapeUtil::MakeOpaqueShape(),
2828       ShapeUtil::MakeShape(F32, {3}),
2829   });
2830   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2831       << "expected: " << ShapeUtil::HumanString(expected)
2832       << "actual:   " << ShapeUtil::HumanString(actual);
2833 }
2834 
2835 TEST_F(HloParserTest, ParseShapeStringWithLayout) {
2836   string shape_string = "f32[123,456]{0,1}";
2837   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2838   Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1});
2839   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2840       << "expected: " << ShapeUtil::HumanString(expected)
2841       << "actual:   " << ShapeUtil::HumanString(actual);
2842 }
2843 
2844 TEST_F(HloParserTest, ParseShapeStringWithTilingLayout) {
2845   // One tile.
2846   string shape_string = "f32[123,456]{0,1:T(2,128)}";
2847   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2848   Shape expected =
2849       ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}, {Tile({2, 128})});
2850   EXPECT_EQ(expected, actual)
2851       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2852       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
2853 
2854   // Tile with negative dimension size for combining dimensions.
2855   shape_string = "f32[123,456,789]{0,1,2:T(2, * , 128)}";
2856   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2857   expected =
2858       ShapeUtil::MakeShapeWithLayout(F32, {123, 456, 789}, {0, 1, 2},
2859                                      {Tile({2, Tile::kCombineDimension, 128})});
2860   EXPECT_EQ(expected, actual)
2861       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2862       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
2863 
2864   // Two tiles.
2865   shape_string = "bf16[123,456,789]{2,1,0:T(2,*,128)(2,1)}";
2866   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2867   expected = ShapeUtil::MakeShapeWithLayout(
2868       BF16, {123, 456, 789}, {2, 1, 0},
2869       {Tile({2, Tile::kCombineDimension, 128}), Tile({2, 1})});
2870   EXPECT_EQ(expected, actual)
2871       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2872       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
2873 
2874   // Tile with element size in bits.
2875   shape_string = "pred[123,456]{1,0:T(2,128)E(1)}";
2876   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2877   expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0},
2878                                             {Tile({2, 128})}, 1);
2879   EXPECT_EQ(expected, actual)
2880       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2881       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
2882 
2883   // Element size in bits without tile.
2884   shape_string = "pred[123,456]{1,0:E(1)}";
2885   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2886   expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, 1);
2887   EXPECT_EQ(expected, actual)
2888       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2889       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
2890 
2891   // Wrong minor_to_major.
2892   shape_string = "f32[123,456,789]{1:T(2, * , 128)}";
2893   auto result = ParseShape(shape_string);
2894   ExpectHasSubstr(result.status().error_message(),
2895                   "Dimensions size is 3, but minor to major size is 1.");
2896 }
2897 
2898 TEST_F(HloParserTest, ParseShapeStringWithMemorySpaceLayout) {
2899   // Tile, element size, and memory space.
2900   string shape_string = "pred[123,456]{1,0:T(2,128)E(1)S(3)}";
2901   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2902   Shape expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0},
2903                                                   {Tile({2, 128})}, 1, 3);
2904   EXPECT_EQ(expected, actual)
2905       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2906       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
2907 
2908   // Element size and memory space.
2909   shape_string = "pred[123,456]{1,0:E(1)S(3)}";
2910   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2911   expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, 1, 3);
2912   EXPECT_EQ(expected, actual)
2913       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2914       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
2915 
2916   // Memory space only.
2917   shape_string = "pred[123,456]{1,0:S(3)}";
2918   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2919   expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, 0, 3);
2920   EXPECT_EQ(expected, actual)
2921       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2922       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
2923 }
2924 
2925 TEST_F(HloParserTest, ParseOpaqueType) {
2926   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("opaque[]"));
2927   Shape expected = ShapeUtil::MakeOpaqueShape();
2928   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2929       << "expected: " << ShapeUtil::HumanString(expected)
2930       << "actual:   " << ShapeUtil::HumanString(actual);
2931 }
2932 
2933 TEST_F(HloParserTest, ParseTokenType) {
2934   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("token[]"));
2935   Shape expected = ShapeUtil::MakeTokenShape();
2936   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2937       << "expected: " << ShapeUtil::HumanString(expected)
2938       << "actual:   " << ShapeUtil::HumanString(actual);
2939 }
2940 
2941 TEST_F(HloParserTest, ParseInvalidShapeString) {
2942   string shape_strings[] = {"f32[123,456]foobar{0,1}", "f32[123,456]{foo}",
2943                             "f32[123,456]dense{foo}"};
2944   for (const string& shape_string : shape_strings) {
2945     StatusOr<Shape> result = ParseShape(shape_string);
2946     ASSERT_FALSE(result.ok()) << "shape: " << shape_string;
2947   }
2948 }
2949 
2950 TEST_F(HloParserTest, ParseDynamicArray) {
2951   string shape_string = "f32[123,<=456]";
2952   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2953   Shape expected = ShapeUtil::MakeShape(F32, {123, 456}, {false, true});
2954   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2955       << "expected: " << ShapeUtil::HumanString(expected)
2956       << "actual:   " << ShapeUtil::HumanString(actual);
2957 }
2958 
2959 TEST_F(HloParserTest, ParseDynamicTuple) {
2960   string shape_string = "(f32[42], u32[<=123,<=456])";
2961   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2962   Shape expected = ShapeUtil::MakeTupleShape(
2963       {ShapeUtil::MakeShape(F32, {42}),
2964        ShapeUtil::MakeShape(U32, {123, 456}, {true, true})});
2965   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2966       << "expected: " << ShapeUtil::HumanString(expected)
2967       << "actual:   " << ShapeUtil::HumanString(actual);
2968 }
2969 
2970 TEST_F(HloParserTest, NegativeParameterNumber) {
2971   const string hlo_string = "par0 = f32[3,5] parameter(-1)";
2972   auto result = ParseAndReturnUnverifiedModule(hlo_string);
2973   ASSERT_FALSE(result.status().ok());
2974   EXPECT_THAT(result.status().error_message(),
2975               ::testing::HasSubstr("parameter number must be >= 0"));
2976 }
2977 
2978 TEST_F(HloParserTest, WrongNumberOfParameterLeafBuffersInReplication) {
2979   const string hlo_string =
2980       "par0 = (f32[3,5], f32[]) parameter(0), "
2981       "parameter_replication={true,false,true}";
2982   auto result = ParseAndReturnUnverifiedModule(hlo_string);
2983   ASSERT_FALSE(result.status().ok());
2984   EXPECT_THAT(result.status().error_message(),
2985               ::testing::HasSubstr("parameter has 2 leaf buffers, but "
2986                                    "parameter_replication has 3 elements"));
2987 }
2988 
2989 TEST_F(HloParserTest, CheckIndexedConditionalDimension) {
2990   const char* const hlo_string = R"(
2991   HloModule Module
2992 
2993   branch0 {
2994     tparam = f32[4] parameter(0)
2995     ROOT tgte1 = f32[4] ceil(tparam)
2996   }
2997 
2998   branch1 {
2999     fparam = f32[4] parameter(0)
3000     ROOT fgte1 = f32[4] floor(fparam)
3001   }
3002 
3003   ENTRY entry {
3004     p0 = f32[4] parameter(0)
3005     b0 = s32[2] parameter(1)
3006     ROOT conditional = f32[4] conditional(b0, p0, p0),
3007       branch_computations={branch0, branch1}
3008   }
3009   )";
3010   auto result = ParseAndReturnUnverifiedModule(hlo_string);
3011   EXPECT_NE(Status::OK(), result.status());
3012   EXPECT_THAT(result.status().error_message(),
3013               ::testing::HasSubstr("The first operand must be a scalar"));
3014 }
3015 
3016 TEST_F(HloParserTest, CheckIndexedConditionalElementType) {
3017   const char* const hlo_string = R"(
3018   HloModule Module
3019 
3020   branch0 {
3021     tparam = f32[4] parameter(0)
3022     ROOT tgte1 = f32[4] ceil(tparam)
3023   }
3024 
3025   branch1 {
3026     fparam = f32[4] parameter(0)
3027     ROOT fgte1 = f32[4] floor(fparam)
3028   }
3029 
3030   ENTRY entry {
3031     p0 = f32[4] parameter(0)
3032     b0 = f32[] parameter(1)
3033     ROOT conditional = f32[4] conditional(b0, p0, p0),
3034       branch_computations={branch0, branch1}
3035   }
3036   )";
3037   auto result = ParseAndReturnUnverifiedModule(hlo_string);
3038   EXPECT_NE(Status::OK(), result.status());
3039   EXPECT_THAT(result.status().error_message(),
3040               ::testing::HasSubstr(
3041                   "The first operand must be a scalar of PRED or S32"));
3042 }
3043 
3044 TEST_F(HloParserTest,
3045        CheckPredicatedConditionalRequiresTrueAndFalseComputation) {
3046   const char* const hlo_string = R"(
3047   HloModule Module
3048 
3049   branch0 {
3050     tparam = f32[4] parameter(0)
3051     ROOT tgte1 = f32[4] ceil(tparam)
3052   }
3053 
3054   branch1 {
3055     fparam = f32[4] parameter(0)
3056     ROOT fgte1 = f32[4] floor(fparam)
3057   }
3058 
3059   ENTRY entry {
3060     p0 = f32[4] parameter(0)
3061     b0 = pred[] parameter(1)
3062     ROOT conditional = f32[4] conditional(b0, p0, p0),
3063       branch_computations={branch0, branch1}
3064   }
3065   )";
3066   auto result = ParseAndReturnUnverifiedModule(hlo_string);
3067   EXPECT_NE(Status::OK(), result.status());
3068   EXPECT_THAT(
3069       result.status().error_message(),
3070       ::testing::HasSubstr("unexpected attribute \"branch_computations\""));
3071 }
3072 
3073 }  // namespace
3074 }  // namespace xla
3075