1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_parser.h"
17
18 #include <memory>
19 #include <string>
20
21 #include "absl/memory/memory.h"
22 #include "absl/strings/match.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
26 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
27 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
28 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
31 #include "tensorflow/compiler/xla/window_util.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test.h"
35
36 namespace xla {
37 namespace {
38
39 namespace m = ::xla::match;
40 using absl::string_view;
41
// A single parser round-trip case. Each case supplies HLO module text that is
// expected to parse successfully and to print back out identically, plus the
// name under which the parameterized test instance is reported.
//
// NOTE: members are aggregate-initialized positionally by the test-case table
// (e.g. {"AxpyParam", R"(...)"}), so their declaration order must not change.
struct TestData {
  // Human-readable case name; returned verbatim by TestDataToString.
  std::string test_name;
  // The HLO module text to parse and compare against its stringification.
  std::string module_string;
  // Defaults to true. Presumably disables HLO verification of the parsed
  // module when false; the consuming code is not visible here, so confirm.
  bool enable_verification{true};
};
47
TestDataToString(const::testing::TestParamInfo<TestData> & data)48 string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
49 return data.param.test_name;
50 }
51
52 // For each string below, we check that:
53 // - we parse it to an HloModule successfully, and
54 // - the stringification of the resulting HloModule is equal to our original
55 // string.
CreateTestCases()56 std::vector<TestData> CreateTestCases() {
57 // clang-format off
58 return std::vector<TestData>({
59 // ax + y
60 {
61 "AxpyParam",
62 R"(HloModule axpy_module
63
64 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
65 %alpha = f32[] parameter(0)
66 %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
67 %x = f32[2,4]{1,0} parameter(1)
68 %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
69 %y = f32[2,4]{1,0} parameter(2)
70 ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
71 }
72
73 )"
74 },
75 // parameter replication
76 {
77 "ParamReplication",
78 R"(HloModule param_replication_module
79
80 ENTRY %param_replication (a: f32[], b: (f32[2,4], (f32[2,4]))) -> (f32[], (f32[2,4], (f32[2,4]))) {
81 %a = f32[] parameter(0), parameter_replication={true}
82 %b = (f32[2,4]{1,0}, (f32[2,4]{1,0})) parameter(1), parameter_replication={false,true}
83 ROOT %tuple = (f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0}))) tuple(f32[] %a, (f32[2,4]{1,0}, (f32[2,4]{1,0})) %b)
84 }
85
86 )"
87 },
88 // pred constant
89 {
90 "ConstantPred",
91 R"(HloModule constant_pred_module
92
93 ENTRY %constant_pred () -> pred[] {
94 ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar"
95 }
96
97 )"
98 },
99 // pred array constant
100 {
101 "ConstantPredArray",
102 R"(HloModule module
103
104 ENTRY %constant_pred_array () -> pred[2,3] {
105 ROOT %constant = pred[2,3]{1,0} constant({ { 0, 1, 0 }, { 1, 0, 1 } })
106 }
107
108 )"
109 },
110
111 // s32 constant
112 {
113 "ConstantS32",
114 R"(HloModule constant_s32_module
115
116 ENTRY %constant_s32 () -> s32[] {
117 ROOT %constant = s32[] constant(-42)
118 }
119
120 )"
121 },
122 // f32 constant, but the value is not a decimal and there is a backend
123 // configuration
124 {
125 "ConstantF32",
126 R"(HloModule ConstantF32_module
127
128 ENTRY %ConstantF32.v4 () -> f32[] {
129 ROOT %constant = f32[] constant(42), backend_config="this is a configuration"
130 }
131
132 )"
133 },
134 // f32 constant, rank 1 empty array.
135 {
136 "ConstantF32R1Empty",
137 R"(HloModule ConstantF32Empty_module
138
139 ENTRY %ConstantF32Empty.v4 () -> f32[0] {
140 ROOT %constant = f32[0]{0} constant({})
141 }
142
143 )"
144 },
145 // f32 constant, rank 4 empty array.
146 {
147 "ConstantF32R4Empty",
148 R"(HloModule ConstantF32R4Empty_module
149
150 ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] {
151 ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant({ { /*i0=0*/ }, { /*i0=1*/ } })
152 }
153
154 )"
155 },
156 // constant 4D
157 {
158 "Constant4D",
159 R"(HloModule Small_3x2x1x1_module
160
161 ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] {
162 ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
163 }
164
165 )"
166 },
167 // non-finite constants: nan, inf, -inf
168 {
169 "ConstantNonFinite",
170 R"(HloModule IsFiniteR1F32s_module
171
172 ENTRY %IsFiniteR1F32s.v2 () -> pred[6] {
173 %constant = f32[6]{0} constant({nan, 7, nan, -1, inf, -inf})
174 ROOT %is-finite = pred[6]{0} is-finite(f32[6]{0} %constant)
175 }
176
177 )"
178 },
179 // constant f16
180 {
181 "ConstantF16",
182 R"(HloModule ConstantF16_module
183
184 ENTRY %ConstantF16.v4 () -> f16[] {
185 ROOT %constant = f16[] constant(500)
186 }
187
188 )"
189 },
190 // bf16
191 {
192 "BF16",
193 R"(HloModule BF16
194
195 ENTRY %BF16.v4 () -> bf16[] {
196 ROOT %constant = bf16[] constant(500)
197 }
198
199 )"
200 },
201 // constant + constant
202 {
203 "AddConstants",
204 R"(HloModule add_constants_module
205
206 ENTRY %add_constants () -> f32[] {
207 %constant = f32[] constant(3.14)
208 ROOT %add = f32[] add(f32[] %constant, f32[] %constant)
209 }
210
211 )"
212 },
213 // tuple constant
214 {
215 "TupleConstant",
216 R"(HloModule TupleConstant_module
217
218 ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) {
219 ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant(( { {1}, {2} }, {2, 42} ))
220 }
221
222 )"
223 },
224 // v1 > v2 ? v1 : v2
225 {
226 "SelectR1F32",
227 R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module
228
229 ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
230 %v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
231 %v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
232 %greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, sharding={replicated}
233 ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
234 }
235
236 )"
237 },
238 // empty tuple
239 {
240 "EmptyTupleCreate",
241 R"(HloModule EmptyTupleCreate_module
242
243 ENTRY %EmptyTupleCreate.v1 () -> () {
244 ROOT %tuple = () tuple()
245 }
246
247 )"
248 },
249 // tuple
250 {
251 "TupleCreate",
252 R"(HloModule TupleCreate_module
253
254 ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
255 %v1 = f32[] parameter(0)
256 %v2 = f32[3]{0} parameter(1)
257 %v3 = f32[2,3]{1,0} parameter(2)
258 ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3)
259 }
260
261 )"
262 },
263 {
264 "ShardedTupleCreate",
265 R"(HloModule ShardedTupleCreate_module
266
267 ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
268 %v1 = f32[] parameter(0)
269 %v2 = f32[3]{0} parameter(1)
270 %v3 = f32[2,3]{1,0} parameter(2)
271 ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}}
272 }
273
274 )"
275 },
276 {
277 "DomainParsing",
278 R"(HloModule DomainParsing_module
279
280 ENTRY %DomainParsing (v1: f32[]) -> f32[] {
281 %v1 = f32[] parameter(0)
282 ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
283 }
284
285 )"
286 },
287 // int32 result = 0;
288 // while (result < 5) { result = result + 1; }
289 {
290 "WhileWithScalarS32Result",
291 R"(HloModule WhileWithScalarS32Result_module
292
293 %body.v3 (prev.1: s32[]) -> s32[] {
294 %constant = s32[] constant(1)
295 %prev.1 = s32[] parameter(0)
296 ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1)
297 }
298
299 %condition.v3 (prev.2: s32[]) -> pred[] {
300 %constant.1 = s32[] constant(5)
301 %prev.2 = s32[] parameter(0)
302 ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT
303 }
304
305 ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
306 %constant.2 = s32[] constant(0)
307 ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3
308 }
309
310 )"
311 },
312 // copy-start and copy-done
313 {
314 "CopyStartAndCopyDone",
315
316 R"(HloModule CopyStartAndCopyDone_module
317
318 ENTRY %CopyStartAndCopyDone (v1: f32[], v2: f32[2,3]) -> (f32[], f32[2,3]) {
319 %v1 = f32[] parameter(0)
320 %copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1)
321 %copy-done.1 = f32[] copy-done((f32[], f32[], u32[]) %copy-start.1)
322 %v2 = f32[2,3]{1,0:S(1)} parameter(1)
323 %copy-start.2 = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2)
324 %copy-done.2 = f32[2,3]{1,0:S(2)} copy-done((f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) %copy-start.2)
325 ROOT %tuple = (f32[], f32[2,3]{1,0:S(2)}) tuple(f32[] %copy-done.1, f32[2,3]{1,0:S(2)} %copy-done.2)
326 }
327
328 )"
329 },
330 // send and recv
331 {
332 "SendRecv",
333 R"(HloModule TwoSendRecvBothWayRecvFist_module
334
335 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
336 %token0 = token[] after-all()
337 %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1}
338 ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
339 %constant = f32[] constant(2.1), sharding={maximal device=0}
340 %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
341 %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}
342 }
343
344 )"
345 },
346 {
347 "SendRecvWithHostTransfer",
348 R"(HloModule HostTransferSendRecv_module
349
350 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
351 %token0 = token[] after-all()
352 %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, is_host_transfer=true
353 ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true
354 %constant = f32[] constant(2.1), sharding={maximal device=0}
355 %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, is_host_transfer=true
356 %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true
357 }
358
359 )"
360 },
361 // get-tuple-element
362 {
363 "GetTupleElement",
364 R"(HloModule GetTupleElement_module
365
366 ENTRY %GetTupleElement.v4 () -> s32[2,3] {
367 %constant = f32[3]{0} constant({1, 2, 3})
368 %constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } })
369 %tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1)
370 ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0}
371 }
372
373 )"
374 },
375 // call
376 {
377 "Call",
378 R"(HloModule CallR0F32IdentityScalar_module
379
380 %Identity.v1 (x: f32[]) -> f32[] {
381 ROOT %x = f32[] parameter(0)
382 }
383
384 ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] {
385 %constant = f32[] constant(42)
386 ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1
387 }
388
389 )"
390 },
391 // CustomCall with backend_config.
392 {
393 "CustomCallWithOpaque",
394 R"(HloModule custom_call
395
396 ENTRY %CustomCall () -> f32[1,2,3] {
397 %constant = f32[1]{0} constant({12345})
398 ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar", backend_config="this string is opaque"
399 }
400
401 )"
402 },
403 // reduce window
404 {
405 "ReduceWindow",
406 R"(HloModule R4UnitWindow_module
407
408 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
409 %lhs = f32[] parameter(0)
410 %rhs = f32[] parameter(1)
411 ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
412 }
413
414 ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] {
415 %operand = f32[13,12,8,15]{0,3,2,1} parameter(0)
416 %constant = f32[] constant(0)
417 ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3
418 }
419
420 )"
421 },
422 // reduce window on scalar
423 {
424 "ReduceWindowScalar",
425 R"(HloModule reduce_window_scalar
426
427 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
428 %lhs = f32[] parameter(0)
429 %rhs = f32[] parameter(1)
430 ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
431 }
432
433 ENTRY %R4UnitWindowScalar () -> f32[] {
434 %constant = f32[] constant(42)
435 %constant.1 = f32[] constant(1)
436 ROOT %reduce-window = f32[] reduce-window(f32[] %constant, f32[] %constant.1), to_apply=%add_F32.v3
437 }
438
439 )"
440 },
441 // convolution
442 {
443 "Convolution",
444 R"(HloModule Convolve1D1Window_0_module
445
446 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
447 %input = f32[1,2,1]{2,1,0} parameter(0)
448 %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
449 %filter = f32[1,1,1]{2,1,0} parameter(1)
450 ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}
451 }
452
453 )"
454 },
455 // convolution rank 2
456 {
457 "ConvolutionR2",
458 R"(HloModule ConvolveR2_module
459
460 ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[2,2]) -> f32[1,2] {
461 %input = f32[1,2]{1,0} parameter(0)
462 %filter = f32[2,2]{1,0} parameter(1)
463 ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[2,2]{1,0} %filter), dim_labels=bf_io->bf
464 }
465
466 )"
467 },
468 // convolution backward
469 {
470 "ConvolutionBackward",
471 R"(HloModule ConvolveBackward_module
472
473 ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] {
474 %input = f32[128,7,7,512]{0,3,2,1} parameter(0)
475 %filter = f32[3,3,512,512]{3,2,1,0} parameter(1)
476 ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f
477 }
478
479 )"
480 },
481 // reverse(constant)
482 {
483 "Reverse4D",
484 R"(HloModule Reverse4DFloatArrayOnDim01_module
485
486 ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] {
487 %constant = f32[4,3,2,1]{0,1,2,3} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } })
488 ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1}
489 }
490
491 )"
492 },
493 // concat
494 {
495 "Concat",
496 R"(HloModule Concat2x3With2x5_module
497
498 ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] {
499 %constant = f32[2,3]{1,0} constant({ { 0, 1, 2 }, { 1000, 1001, 1002 } })
500 %constant.1 = f32[2,5]{1,0} constant({ { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } })
501 ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1}
502 }
503
504 )"
505 },
506 // select and scatter
507 {
508 "SelectAndScatter",
509 R"(HloModule R4F32OverlapSmall_module
510
511 %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
512 %lhs = f32[] parameter(0)
513 %rhs = f32[] parameter(1)
514 ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
515 }
516
517 %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
518 %lhs.1 = f32[] parameter(0)
519 %rhs.1 = f32[] parameter(1)
520 ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
521 }
522
523 ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] {
524 %constant = f32[4,5,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } })
525 %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } })
526 %constant.2 = f32[] constant(0)
527 ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3
528 }
529
530 )"
531 },
532 // select and scatter on scalar
533 {
534 "SelectAndScatterScalar",
535 R"(HloModule select_and_scatter_scalar
536
537 %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
538 %lhs = f32[] parameter(0)
539 %rhs = f32[] parameter(1)
540 ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
541 }
542
543 %add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
544 %lhs.1 = f32[] parameter(0)
545 %rhs.1 = f32[] parameter(1)
546 ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
547 }
548
549 ENTRY %SelectAndScatterScalar () -> f32[] {
550 %constant = f32[] constant(42)
551 %constant.1 = f32[] constant(1)
552 %constant.2 = f32[] constant(2)
553 ROOT %select-and-scatter = f32[] select-and-scatter(f32[] %constant, f32[] %constant.1, f32[] %constant.2), select=%ge_F32.v3, scatter=%add_F32.v3
554 }
555
556 )"
557 },
558 // slice
559 {
560 "Slice",
561 R"(HloModule slice_module
562
563 ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
564 %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
565 ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3:1], [0:3:1], [0:4:2], [0:4:1]}
566 }
567
568 )"
569 },
570 // slice, no stride
571 {
572 "SliceNoStride",
573 R"(HloModule Slice3x3x3_To_1x3x3_F32_module
574
575 ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] {
576 %constant = f32[3,3,3]{2,1,0} constant({ { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } })
577 ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]}
578 }
579
580 )"
581 },
582 // slice R0
583 {
584 "SliceR0",
585 R"(HloModule SliceR0_module
586
587 ENTRY %SliceR0.v2 () -> s32[] {
588 %constant = s32[] constant(1)
589 ROOT %slice = s32[] slice(s32[] %constant), slice={}
590 }
591
592 )"
593 },
594 // transpose
595 {
596 "Transpose",
597 R"(HloModule Transpose_module
598
599 ENTRY %Transpose.v2 () -> s32[1,2,3] {
600 %constant = s32[1,2,3]{2,1,0} constant({ { { 1, 2, 3 }, { 4, 5, 6 } } })
601 ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2}
602 }
603
604 )"
605 },
606 {
607 "TransposeC128",
608 R"(HloModule TransposeC128_module
609
610 ENTRY %Transpose.v3 (input: c128[1,2,3]) -> c128[1,2,3] {
611 %input = c128[1,2,3]{2,1,0} parameter(0)
612 ROOT %transpose = c128[1,2,3]{2,1,0} transpose(c128[1,2,3]{2,1,0} %input), dimensions={0,1,2}
613 }
614
615 )"
616 },
617 // Triangular solve
618 {
619 "TriangularSolve",
620 R"(HloModule TriangularSolve_module
621
622 ENTRY %SimpleRightLowerNotranspose.4 (a.1: f32[4,4], b.2: f32[3,4]) -> f32[3,4] {
623 %a.1 = f32[4,4]{1,0} parameter(0)
624 %b.2 = f32[3,4]{1,0} parameter(1)
625 ROOT %triangular-solve.3 = f32[3,4]{1,0} triangular-solve(f32[4,4]{1,0} %a.1, f32[3,4]{1,0} %b.2), lower=true, transpose_a=NO_TRANSPOSE
626 }
627
628 )"
629 },
630 // Dynamic slice
631 {
632 "DynamicSlice",
633 R"(HloModule DynamicSlice_module
634
635 ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] {
636 %original_parameter = s32[2,2,258]{2,1,0} parameter(0)
637 %constant = s32[1]{0} constant({0})
638 %start_index = s32[1]{0} parameter(1)
639 %concatenate = s32[3]{0} concatenate(s32[1]{0} %constant, s32[1]{0} %constant, s32[1]{0} %start_index), dimensions={0}
640 ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258}
641 }
642
643 )"
644 },
645 // Dynamic slice with scalar indices
646 {
647 "DynamicSliceScalarIndices",
648 R"(HloModule DynamicSlice_module
649
650 ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
651 %original_parameter = s32[2,2,258]{2,1,0} parameter(0)
652 %constant = s32[] constant(0)
653 %start_index = s32[] parameter(1)
654 ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
655 }
656
657 )"
658 },
659 // Dynamic update slice
660 {
661 "DynamicUpdateSlice",
662 R"(HloModule DynamicSlice_module
663
664 ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] {
665 %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
666 %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
667 %start_indices = s32[4]{0} parameter(2)
668 ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices)
669 }
670
671 )"
672 },
673 // Dynamic update slice with scalar indices
674 {
675 "DynamicUpdateSliceScalarIndex",
676 R"(HloModule DynamicUpdateSlice_module
677
678 ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
679 %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
680 %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
681 %start_index.0 = s32[] parameter(2)
682 %start_index.1 = s32[] parameter(3)
683 %start_index.2 = s32[] parameter(4)
684 %start_index.3 = s32[] parameter(5)
685 ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3)
686 }
687
688 )"
689 },
690 // batch norm training
691 {
692 "BatchNormTraining",
693 R"(HloModule BasicTraining_module
694
695 ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) {
696 %constant = f32[2,2,1,2]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } })
697 %constant.1 = f32[2]{0} constant({2, 3})
698 %constant.2 = f32[2]{0} constant({1, 2})
699 ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3
700 }
701
702 )"
703 },
704 // batch norm inference
705 {
706 "BatchNormInference",
707 R"(HloModule BatchNormInference_module
708
709 ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] {
710 %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
711 %offset = f32[2]{0} parameter(1)
712 %scale = f32[2]{0} parameter(2)
713 %mean = f32[2]{0} parameter(3)
714 %variance = f32[2]{0} parameter(4)
715 ROOT %batch-norm-inference = f32[2,2,2,2]{3,2,1,0} batch-norm-inference(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance), epsilon=0.001, feature_index=0
716 }
717
718 )"
719 },
720 // batch norm grad
721 {
722 "BatchNormGrad",
723 R"(HloModule BatchNormGrad_module
724
725 ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) {
726 %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
727 %scale = f32[2]{0} parameter(1)
728 %mean = f32[2]{0} parameter(2)
729 %variance = f32[2]{0} parameter(3)
730 %grad_output = f32[2,2,2,2]{3,2,1,0} parameter(4)
731 ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0
732 }
733
734 )"
735 },
736 // fft
737 {
738 "Fft",
739 R"(HloModule Fft_module
740
741 ENTRY %Fft (input: c64[8,32]) -> c64[8,32] {
742 %input = c64[8,32]{1,0} parameter(0)
743 ROOT %fft = c64[8,32]{1,0} fft(c64[8,32]{1,0} %input), fft_type=FFT, fft_length={32}
744 }
745
746 )"
747 },
748 // ifft
749 {
750 "Ifft2d",
751 R"(HloModule Ifft2d_module
752
753 ENTRY %Ifft2d (input: c64[5,8,32]) -> c64[5,8,32] {
754 %input = c64[5,8,32]{2,1,0} parameter(0)
755 ROOT %fft = c64[5,8,32]{2,1,0} fft(c64[5,8,32]{2,1,0} %input), fft_type=IFFT, fft_length={8,32}
756 }
757
758 )"
759 },
760 // rfft2d
761 {
762 "Rfft2d",
763 R"(HloModule Rfft2d_module
764
765 ENTRY %Rfft2d (input: f32[5,64,32]) -> c64[5,64,17] {
766 %input = f32[5,64,32]{2,1,0} parameter(0)
767 ROOT %fft = c64[5,64,17]{2,1,0} fft(f32[5,64,32]{2,1,0} %input), fft_type=RFFT, fft_length={64,32}
768 }
769
770 )"
771 },
772 // irfft3d
773 {
774 "Irfft3d",
775 R"(HloModule Irfft3d_module
776
777 ENTRY %Irfft3d (input: c64[5,64,128,33]) -> f32[5,64,128,64] {
778 %input = c64[5,64,128,33]{3,2,1,0} parameter(0)
779 ROOT %fft = f32[5,64,128,64]{3,2,1,0} fft(c64[5,64,128,33]{3,2,1,0} %input), fft_type=IRFFT, fft_length={64,128,64}
780 }
781
782 )"
783 },
784 // pad
785 {
786 "Pad",
787 R"(HloModule Pad1DS3Array_module
788
789 ENTRY %Pad1DS3Array.v3 () -> f32[7] {
790 %constant = f32[3]{0} constant({1, 2, 3})
791 %constant.1 = f32[] constant(0.1)
792 ROOT %pad = f32[7]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1
793 }
794
795 )"
796 },
797 // pad has interior
798 {
799 "PadHasInterior",
800 R"(HloModule PadHasInterior_module
801
802 ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] {
803 %input = f32[1,25,7,7]{3,2,1,0} parameter(0)
804 %constant = f32[] constant(-5.123)
805 ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0
806 }
807
808 )"
809 },
810 // Negative padding
811 {
812 "PadHasNegativePadding",
813 R"(HloModule PadHasNegativePadding_module
814
815 ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,35] {
816 %input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0)
817 %constant = f32[] constant(-5.123)
818 ROOT %pad = f32[1,15,6,3,35]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3
819 }
820
821 )"
822 },
823 // fusion
824 {
825 "Fusion",
826 R"(HloModule fusion_module
827
828 %fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] {
829 %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
830 %constant.1.param_1 = f32[2]{0} parameter(1)
831 %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %constant.1.param_1), dimensions={1}
832 ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %constant.param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
833 }
834
835 ENTRY %fusion.v3 () -> f32[3,2,1,1] {
836 %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
837 %constant.1 = f32[2]{0} constant({3.14, 4.25})
838 ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation
839 }
840
841 )"
842 },
843 {
844 "Gather",
845 R"(HloModule StringifyGather
846
847 ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
848 %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
849 %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
850 ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
851 }
852
853 )"
854 },
855 {
856 "SortedGather",
857 R"(HloModule StringifyGather
858
859 ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
860 %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
861 %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
862 ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}, indices_are_sorted=true
863 }
864
865 )"
866 },
867 {
868 "Scatter",
869 R"(HloModule StringifyScatter
870
871 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
872 %lhs = f32[] parameter(0)
873 %rhs = f32[] parameter(1)
874 ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
875 }
876
877 ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
878 %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
879 %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
880 %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
881 ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32.v3
882 }
883
884 )"
885 },
886 {
887 "SortedScatter",
888 R"(HloModule StringifySortedScatter
889
890 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
891 %lhs = f32[] parameter(0)
892 %rhs = f32[] parameter(1)
893 ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
894 }
895
896 ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
897 %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
898 %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
899 %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
900 ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, indices_are_sorted=true, to_apply=%add_F32.v3
901 }
902
903 )"
904 },
905 {
906 "UniqueIndicesScatter",
907 R"(HloModule StringifyUniqueIndicesScatter
908
909 %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
910 %lhs = f32[] parameter(0)
911 %rhs = f32[] parameter(1)
912 ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
913 }
914
915 ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
916 %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
917 %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
918 %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
919 ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, unique_indices=true, to_apply=%add_F32.v3
920 }
921
922 )"
923 },
924 {
925 "ConstantUnsignedNoUnderflow",
926 R"(HloModule ConstantUnsignedNoUnderflow_module
927
928 ENTRY %ConstantUnsignedNoUnderflow () -> u64[] {
929 ROOT %constant = u64[] constant(1)
930 }
931
932 )"
933 },
934
935 {
936 "ConstantUnsignedNoOverflow",
937 R"(HloModule ConstantUnsignedNoOverflow_module
938
939 ENTRY %ConstantUnsignedNoOverflow () -> u64[] {
940 ROOT %constant = u64[] constant(9223372036854775807)
941 }
942
943 )"
944 },
945 // CustomCallWithLayoutConstraints
946 {
947 "CustomCallWithLayoutConstraints",
948 R"(HloModule CustomCallWithLayoutConstraints
949
950 ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
951 %p0 = f32[42,2,3]{0,1,2} parameter(0)
952 %p1 = f32[123,4]{0,1} parameter(1)
953 ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}}
954 }
955
956 )"
957 },
958 // CustomCallWithLayoutConstraintsNoOperands
959 {
960 "CustomCallWithLayoutConstraintsNoOperands",
961 R"(HloModule CustomCallWithLayoutConstraintsNoOperands
962
963 ENTRY %CustomCallWithLayoutConstraints () -> f32[1,2,3] {
964 ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
965 }
966
967 )"
968 },
969 // CustomCallWithLayoutConstraintsTupleShapes
970 {
971 "CustomCallWithLayoutConstraintsTupleShapes",
972 R"(HloModule CustomCallWithLayoutConstraintsTupleShapes
973
974 ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) {
975 %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
976 %p1 = f32[123,4]{0,1} parameter(1)
977 ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}}
978 }
979
980 )"
981 },
982 // CustomCallWithHasSideEffect
983 {
984 "CustomCallWithHasSideEffect",
985 R"(HloModule CustomCallWithHasSideEffect
986
987 ENTRY %CustomCallWithHasSideEffect (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) {
988 %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
989 %p1 = f32[123,4]{0,1} parameter(1)
990 ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", custom_call_has_side_effect=true
991 }
992
993 )"
994 },
995 // Parse c64 literal
996 {
997 "ParseC64Literal",
998 R"(HloModule ParseC64Literal
999
1000 ENTRY %ParseC64Literal () -> c64[2] {
1001 ROOT %c = c64[2]{0} constant({(1, 2), (-inf, nan)})
1002 }
1003
1004 )"
1005 },
1006 // Parse c128 literal
1007 {
1008 "ParseC128Literal",
1009 R"(HloModule ParseC128Literal
1010
1011 ENTRY %ParseC128Literal () -> c128[2] {
1012 ROOT %c = c128[2]{0} constant({(1, 2), (-inf, nan)})
1013 }
1014
1015 )"
1016 },
1017 // Indexed Conditional
1018 {
1019 "IndexedConditional",
1020 R"(HloModule indexed_conditional
1021
1022 %Negate (x: f32[]) -> f32[] {
1023 %x = f32[] parameter(0)
1024 ROOT %negate = f32[] negate(f32[] %x)
1025 }
1026
1027 %Identity (y: f32[]) -> f32[] {
1028 %y = f32[] parameter(0)
1029 ROOT %copy = f32[] copy(f32[] %y)
1030 }
1031
1032 %Floor (z: f32[]) -> f32[] {
1033 %z = f32[] parameter(0)
1034 ROOT %floor = f32[] floor(f32[] %z)
1035 }
1036
1037 ENTRY %Parameters1.v4 () -> f32[] {
1038 %constant = s32[] constant(1)
1039 %constant.1 = f32[] constant(56)
1040 %constant.2 = f32[] constant(12)
1041 %constant.3 = f32[] constant(13)
1042 ROOT %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor}
1043 }
1044
1045 )"
1046 },
1047 // rng-get-and-update-state
1048 {
1049 "RngGetAndUpdateState",
1050 R"(HloModule rng_get_and_update_state
1051
1052 ENTRY %RngGetAndUpdateState () -> u64[2] {
1053 ROOT %rng-get-and-update-state = u64[2]{0} rng-get-and-update-state(), delta=4096
1054 }
1055
1056 )"
1057 },
1058 });
1059 // clang-format on
1060 }
1061
// Returns the round-trip test cases whose module strings are written in the
// "short" textual form, i.e. the form printed by
// HloPrintOptions::ShortParsable(): instruction names carry no '%' prefix and
// operand lists hold only operand names, without operand shapes.
//
// These cases are consumed by the HloParserTestShort / HloParserTestShortProto
// fixtures. Each string is parsed, optionally round-tripped through HloProto,
// printed again and compared byte-for-byte with the original. Whitespace
// inside the raw string literals is therefore significant: keep the blank
// lines and the two-space instruction indentation exactly as the printer
// emits them.
std::vector<TestData> CreateShortTestCases() {
  // clang-format off (the raw strings must not be re-wrapped)
  return std::vector<TestData>({
// map
{
"Map",
R"(HloModule MapBinaryAdder_module

add_F32.v3 {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY MapBinaryAdder.v3 {
  param0 = f32[4]{0} parameter(0)
  param1 = f32[4]{0} parameter(1)
  ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3
}

)"
},
// reduce
{
"Reduce",
R"(HloModule ReduceR3ToR2_module

add_F32.v3 {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY ReduceR3ToR2.v3 {
  input = f32[8,16,256]{2,1,0} parameter(0)
  constant = f32[] constant(0)
  ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
}

)"
},
// tuple reduce (variadic reduce computing max and argmax together)
{
"TupleReduce",
R"(HloModule TupleReduce

max_argmax {
  value = f32[] parameter(2)
  prev_max = f32[] parameter(0)
  is_next_larger = pred[] compare(value, prev_max), direction=GE
  max = f32[] select(is_next_larger, value, prev_max)
  index = s32[] parameter(3)
  prev_argmax = s32[] parameter(1)
  argmax = s32[] select(is_next_larger, index, prev_argmax)
  ROOT pair = (f32[], s32[]) tuple(max, argmax)
}

ENTRY reduce_entry {
  values = f32[1024]{0} parameter(0)
  indices = s32[1024]{0} parameter(1)
  init_value = f32[] constant(-inf)
  init_index = s32[] constant(-1)
  ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax
}

)"
},
// infeed/outfeed
{
"InfeedOutfeed",
R"(HloModule outfeed_module

ENTRY InfeedToOutfeed {
  token0 = token[] after-all()
  infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
  infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
  outfeed = token[] outfeed(infeed.data, token0)
  ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0)
  infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
  infeed.1.token = token[] get-tuple-element(infeed.1), index=1
  outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token)
}

)"
},
// Rng
{
"Rng",
R"(HloModule rng_module

ENTRY Rng {
  constant = f32[] constant(0)
  constant.1 = f32[] constant(1)
  ROOT rng = f32[8]{0} rng(constant, constant.1), distribution=rng_uniform
}

)"
},
// Reduce precision
{
"ReducePrecision",
R"(HloModule reduce_precision

ENTRY ReducePrecision {
  constant = f32[1]{0} constant({3.14159})
  ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10
}

)"
},
// Sort (Key)
{
"SortKey",
R"(HloModule sort

compare {
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY Sort {
  x = f32[1024]{0} parameter(0)
  ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, to_apply=compare
}

)"
},
// Sort (Key, Value)
{
"SortKeyValue",
R"(HloModule sort

compare {
  p.1.lhs = s32[] parameter(2)
  p.1.rhs = s32[] parameter(3)
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY Sort {
  keys = f32[1024]{0} parameter(0)
  values = s32[1024]{0} parameter(1)
  ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
}

)"
},
// R2 Sort (Key)
{
"SortKeyR2",
R"(HloModule sort

compare {
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY Sort {
  x = f32[1024,16]{0,1} parameter(0)
  ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0}, to_apply=compare
}

)"
},
// R2 Sort (Key, Value)
{
"SortKeyValueR2",
R"(HloModule sort

compare {
  p.1.lhs = s32[] parameter(2)
  p.1.rhs = s32[] parameter(3)
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY Sort {
  keys = f32[1024,16]{0,1} parameter(0)
  values = s32[1024,16]{0,1} parameter(1)
  ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0}, to_apply=compare
}

)"
},
// Sort (Key, Value, Value, Value)
{
"SortManyValues",
R"(HloModule sort

compare {
  p.1.lhs = s32[] parameter(2)
  p.1.rhs = s32[] parameter(3)
  p.2.lhs = u32[] parameter(4)
  p.2.rhs = u32[] parameter(5)
  p.3.lhs = f32[] parameter(6)
  p.3.rhs = f32[] parameter(7)
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY Sort {
  keys = f32[1024,16]{0,1} parameter(0)
  values.0 = s32[1024,16]{0,1} parameter(1)
  values.1 = u32[1024,16]{0,1} parameter(2)
  values.2 = f32[1024,16]{0,1} parameter(3)
  ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare
}

)"
},
// Sort (Key) is_stable=true
{
"SortKeyStable",
R"(HloModule sort

compare {
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY Sort {
  x = f32[1024]{0} parameter(0)
  ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare
}

)"
},
// Indexed Conditional (s32 branch index selecting one of branch_computations)
{
"IndexedConditional",
R"(HloModule indexed_conditional

Negate {
  x = f32[] parameter(0)
  ROOT negate = f32[] negate(x)
}

Identity {
  y = f32[] parameter(0)
  ROOT copy = f32[] copy(y)
}

Floor {
  z = f32[] parameter(0)
  ROOT floor = f32[] floor(z)
}

ENTRY Parameters1.v4 {
  constant = s32[] constant(1)
  constant.1 = f32[] constant(56)
  constant.2 = f32[] constant(12)
  constant.3 = f32[] constant(13)
  ROOT conditional = f32[] conditional(constant, constant.1, constant.2, constant.3), branch_computations={Negate, Identity, Floor}
}

)"
},
// Predicated Conditional (pred selecting true_computation/false_computation)
{
"PredicatedConditional",
R"(HloModule pred_conditional

Negate {
  x = f32[] parameter(0)
  ROOT negate = f32[] negate(x)
}

Identity {
  y = f32[] parameter(0)
  ROOT copy = f32[] copy(y)
}

ENTRY Parameters1.v4 {
  constant = pred[] constant(true)
  constant.1 = f32[] constant(56)
  constant.2 = f32[] constant(12)
  ROOT conditional = f32[] conditional(constant, constant.1, constant.2), true_computation=Negate, false_computation=Identity
}

)"
},
// CustomCall (the target string contains an escaped double quote)
{
"CustomCall",
R"(HloModule custom_call

ENTRY CustomCall {
  constant = f32[1]{0} constant({12345})
  ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar"
}

)"
},
// Variables with non-default names
{
"NonDefaultNames",
R"(HloModule add_constants_module

ENTRY add_constants {
  foo = f32[] constant(3.14)
  ROOT bar = f32[] add(foo, foo)
}

)"
},
// dot with batch and contracting dimensions
{
"Dot",
R"(HloModule dot

ENTRY dot {
  a = f32[2,10]{1,0} parameter(0)
  b = f32[10,2]{1,0} parameter(1)
  ROOT dot = f32[2]{0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_batch_dims={1}, rhs_contracting_dims={0}
}

)"
},
// gather
{
"gather",
R"(HloModule gather

ENTRY Gather {
  input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
  start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
  ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
}

)"
},
// all-reduce
{
"AllReduce",
R"(HloModule CRS

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY CRS {
  input = f32[8]{0} parameter(0)
  ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, to_apply=add
}

)"
},
// all-reduce with subgroups
{
"AllReduceWithSubgroups",
R"(HloModule CRS_Subgroups

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY AllReduceWithSubgroups {
  input = f32[128,32]{0,1} parameter(0)
  ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, to_apply=add
}

)"
},
// all-reduce with constrained layout
{
"AllReduceWithLayout",
R"(HloModule CRS

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY CRS {
  input = f32[8]{0} parameter(0)
  ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, constrain_layout=true, to_apply=add
}

)"
},
// two all-reduces sharing the same channel_id
{
"AllReduceAllReduce",
R"(HloModule CRS

add {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY CRS {
  input = f32[8]{0} parameter(0)
  crs.1 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
  ROOT crs.0 = f32[8]{0} all-reduce(input), channel_id=1, replica_groups={{0}}, to_apply=add
}

)"
},
// all-to-all
{
"AllToAll",
R"(HloModule AllToAll

ENTRY AllToAll {
  input = f32[128,32]{0,1} parameter(0)
  ROOT a2a = (f32[128,32]{0,1}) all-to-all(input), replica_groups={}
}

)"
},
// all-to-all with subgroups
{
"AllToAllWithSubgroups",
R"(HloModule AllToAllWithSubgroups

ENTRY AllToAllWithSubgroups {
  p0 = f32[128,32]{0,1} parameter(0)
  p1 = f32[128,32]{0,1} parameter(1)
  ROOT a2a = (f32[128,32]{0,1}, f32[128,32]{0,1}) all-to-all(p0, p1), replica_groups={{1,2},{3,0}}
}

)"
},
// collective-permute
{
"CollectivePermute",
R"(HloModule CollectivePermute

ENTRY CollectivePermute {
  input = f32[128,32]{0,1} parameter(0)
  ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
}

)"
},
// replica-id
{
"ReplicaId",
R"(HloModule replica-id

ENTRY Replica-id {
  ROOT replica-id = u32[] replica-id()
}

)"
},
// partition-id
{
"PartitionId",
R"(HloModule partition-id

ENTRY PartitionId {
  ROOT id = u32[] partition-id()
}

)"
},
// Iota
{
"Iota",
R"(HloModule iota

ENTRY Iota {
  ROOT iota = f32[100]{0} iota(), iota_dimension=0
}

)"
},
// custom-call with window, dim_labels and feature_group_count
{
"CustomCallWithWindowAndDimLabelsAndFeatureGroupCount",
R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount

ENTRY Computation {
  ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target"
}

)"
},
// is_scheduled=true attribute
{
"ScheduledModule",
R"(HloModule scheduled_module, is_scheduled=true

compare {
  p.1.lhs = s32[] parameter(2)
  p.1.rhs = s32[] parameter(3)
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lhs = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY Sort {
  keys = f32[1024]{0} parameter(0)
  values = s32[1024]{0} parameter(1)
  ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
}

)"
},
// AfterAll with multiple operands
{
"AfterAllWithMultipleOperands",
R"(HloModule AfterAllWithMultipleOperands

ENTRY AfterAllWithMultipleOperands {
  p0 = f32[] parameter(0)
  token0 = token[] after-all()
  token1 = token[] after-all()
  ROOT after-all = token[] after-all(p0, token0, token1)
}

)"
},
// AddDependency
// A dependency chain is created from 'neg' to 'exp' using tokens.
{
"AddDependency",
R"(HloModule AddDependency

ENTRY AddDependency {
  p = f32[] parameter(0)
  neg = f32[] negate(p)
  token0 = token[] after-all(neg)
  p_after_token = f32[] add-dependency(p, token0)
  exp = f32[] exponential(p_after_token)
  ROOT sum = f32[] add(neg, exp)
}

)"
},

// A module containing constants equal to the min/max values of various data
// types.
{
"MinMaxValues",
R"(HloModule MinMaxValues

ENTRY MinMaxValues {
  x.s8 = s8[2]{0} constant({-128, 127})
  x.s16 = s16[2]{0} constant({-32768, 32767})
  x.s32 = s32[2]{0} constant({-2147483648, 2147483647})
  x.u8 = u8[2]{0} constant({0, 255})
  x.u16 = u16[2]{0} constant({0, 65535})
  x.u32 = u32[2]{0} constant({0, 4294967295})
  x.f16 = f16[2]{0} constant({-65504, 65504})
  x.bf16 = bf16[2]{0} constant({-3.39e+38, 3.39e+38})
  x.f32 = f32[2]{0} constant({-3.40282e+38, 3.40282e+38})
  x.f64 = f64[2]{0} constant({-1.79769e+308, 1.79769e+308})
  x.c64 = c64[2]{0} constant({(-3.40282e+38, 3.40282e+38), (3.40282e+38, -3.40282e+38)})
  ROOT c.c128 = c128[2]{0} constant({(-1.79769e+308, 1.79769e+308), (1.79769e+308, -1.79769e+308)})
}

)"
},

// Bitcast-convert usage
{
"BitcastConvert",
R"(HloModule BitcastConvert

ENTRY BitcastConvertUsage {
  p = f32[100]{0} parameter(0)
  ROOT out = u32[100]{0} bitcast-convert(p)
}

)"
},
// parameter with an outer_dimension_partitions attribute
{
"OuterDimensionPartitions",
R"(HloModule OuterDimensionPartitions

ENTRY Test {
  ROOT foo = f32[100]{0} parameter(0), outer_dimension_partitions={0,10,20}
}

)"
},
  });
  // clang-format on
}
1653
1654 // The test class for those tests defined above which round-trip through the
1655 // parser and ToString is templatized on two bool parameters:
1656 //
1657 // short_form : used for the "short" test cases which use the ShortParsable
1658 // output form.
1659 // proto_round_trip : whether the module should also be round-tripped through
1660 // HloProto form. This provides much better coverage for the proto
1661 // serialization/deserialization.
1662 //
1663 // The proto_round_trip=true case also technically covers the Parser->ToString
1664 // roundtrip as well, but separating out the Parser->ToString roundtrip as its
1665 // own test provides better isolation and could conceivably catch weirdo bugs
1666 // which are hidden by interaction between the textual and proto roundtripping.
1667 template <bool short_form, bool proto_round_trip>
1668 class HloParameterizedParserTest
1669 : public ::testing::Test,
1670 public ::testing::WithParamInterface<TestData> {
1671 protected:
1672 // Expects "ToString(ParseHloModule(string)) == string", that is, parses the
1673 // string, asserts that it succeeded, stringifies the parsed module, and
1674 // checks that it equals the original string.
1675 void ExpectEqual() {
1676 std::unique_ptr<HloModule> module;
1677 const string& original = GetParam().module_string;
1678 if (GetParam().enable_verification) {
1679 auto verified_module = absl::make_unique<VerifiedHloModule>(
1680 GetParam().test_name, HloModuleConfig(),
1681 /*verifier_layout_sensitive=*/false,
1682 /*allow_mixed_precision_in_hlo_verifier=*/true,
1683 ShapeUtil::ByteSizeOfElements);
1684 TF_ASSERT_OK(verified_module->ParseHloStringAndVerifyModule(original));
1685 module = std::move(verified_module);
1686 } else {
1687 TF_ASSERT_OK_AND_ASSIGN(module, ParseAndReturnUnverifiedModule(original));
1688 }
1689 if (proto_round_trip) {
1690 TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
1691 module->ToProto(), module->config()));
1692 }
1693 if (short_form) {
1694 EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable()));
1695 } else {
1696 EXPECT_EQ(
1697 original,
1698 module->ToString(HloPrintOptions().set_print_large_constants(true)));
1699 }
1700 }
1701 };
1702
// These using shenanigans are required because the TEST_P macro doesn't like
// template instantiations which contain commas.
//
// Each alias fixes the template flags of HloParameterizedParserTest in the
// order <short_form, proto_round_trip>.
using HloParserTestLong = HloParameterizedParserTest<false, false>;
using HloParserTestLongProto = HloParameterizedParserTest<false, true>;
using HloParserTestShort = HloParameterizedParserTest<true, false>;
using HloParserTestShortProto = HloParameterizedParserTest<true, true>;

// Every fixture runs the same round-trip check, ExpectEqual(), once per
// parameter value supplied by the instantiations below.
TEST_P(HloParserTestLong, Run) { ExpectEqual(); }
TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); }
TEST_P(HloParserTestShort, Run) { ExpectEqual(); }
TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); }

// Long-form fixtures take their cases from CreateTestCases(); short-form
// fixtures take theirs from CreateShortTestCases(). TestDataToString names each
// generated test after TestData::test_name.
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestLong,
                         ::testing::ValuesIn(CreateTestCases()),
                         TestDataToString);
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
                         HloParserTestLongProto,
                         ::testing::ValuesIn(CreateTestCases()),
                         TestDataToString);
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestShort,
                         ::testing::ValuesIn(CreateShortTestCases()),
                         TestDataToString);
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
                         HloParserTestShortProto,
                         ::testing::ValuesIn(CreateShortTestCases()),
                         TestDataToString);
1729
1730 class HloParserTest : public ::testing::Test {
1731 protected:
1732 static void ExpectHasSubstr(string_view s, string_view expected) {
1733 EXPECT_TRUE(absl::StrContains(s, expected))
1734 << "'" << s << "' does not contain '" << expected << "'";
1735 }
1736 StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
1737 absl::string_view hlo_text) {
1738 auto module = absl::make_unique<VerifiedHloModule>(
1739 ::testing::UnitTest::GetInstance()->current_test_info()->name(),
1740 HloModuleConfig(),
1741 /*verifier_layout_sensitive=*/false,
1742 /*allow_mixed_precision_in_hlo_verifier=*/true,
1743 ShapeUtil::ByteSizeOfElements);
1744 TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
1745 return std::move(module);
1746 }
1747 };
1748
1749 TEST_F(HloParserTest, Empty) {
1750 const string original = "";
1751 auto result = ParseAndReturnUnverifiedModule(original);
1752 EXPECT_NE(Status::OK(), result.status());
1753 }
1754
1755 TEST_F(HloParserTest, Garbage) {
1756 const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
1757 auto result = ParseAndReturnUnverifiedModule(original);
1758 EXPECT_NE(Status::OK(), result.status());
1759 }
1760
1761 TEST_F(HloParserTest, WrongOpcode) {
1762 const string original = R"(HloModule wrong_opcode:
1763
1764 ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
1765 %x = f32[]{} parameter(0)
1766 %y = f32[]{} parameter(1)
1767 %le = pred[]{} le(f32[]{} %x, f32[]{} %y)
1768 }
1769
1770 )";
1771 auto result = ParseAndReturnUnverifiedModule(original);
1772 EXPECT_NE(Status::OK(), result.status());
1773 }
1774
1775 TEST_F(HloParserTest, MetadataWithCholesky) {
1776 const string original = R"(HloModule metadata_with_cholesky
1777 ENTRY %blabla (a: f32[1,291,291]) -> f32[1,291,291] {
1778 %a = f32[1,291,291] parameter(0)
1779 %out = f32[1,291,291] cholesky(f32[1,291,291] %a), lower=true, metadata={op_type="Cholesky" op_name="Cholesky"}
1780 }
1781 )";
1782 auto result = ParseAndReturnVerifiedModule(original);
1783 EXPECT_EQ(Status::OK(), result.status());
1784 EXPECT_EQ("Cholesky", result.ValueOrDie()
1785 ->entry_computation()
1786 ->root_instruction()
1787 ->metadata()
1788 .op_name());
1789 EXPECT_EQ("Cholesky", result.ValueOrDie()
1790 ->entry_computation()
1791 ->root_instruction()
1792 ->metadata()
1793 .op_type());
1794 }
1795
1796 TEST_F(HloParserTest, WrongShape) {
1797 const string original = R"(HloModule wrong_opcode:
1798
1799 ENTRY %blabla (x: g32[]) -> g32[] {
1800 %x = g32[]{} parameter(0)
1801 }
1802
1803 )";
1804 auto result = ParseAndReturnUnverifiedModule(original);
1805 EXPECT_NE(Status::OK(), result.status());
1806 }
1807
1808 TEST_F(HloParserTest, WrongOperandsSize) {
1809 const string original = R"(HloModule wrong_opcode:
1810
1811 ENTRY %blabla (x: f32[]) -> pred[] {
1812 %x = f32[]{} parameter(0)
1813 %eq = pred[]{} compare(f32[]{} %x), direction=EQ
1814 }
1815
1816 )";
1817 auto result = ParseAndReturnUnverifiedModule(original);
1818 EXPECT_NE(Status::OK(), result.status());
1819 }
1820
1821 TEST_F(HloParserTest, OperandNotFound) {
1822 const string original = R"(HloModule operand_not_found:
1823 ENTRY %blabla (x: f32[]) -> pred[] {
1824 %x = f32[]{} parameter(0)
1825 %eq = pred[]{} compare(f32[]{} %x, f32[]{} %y), direction=EQ
1826 }
1827 )";
1828 auto result = ParseAndReturnUnverifiedModule(original);
1829 EXPECT_NE(Status::OK(), result.status());
1830 }
1831
1832 TEST_F(HloParserTest, MoreConstants) {
1833 const string original = R"(HloModule SelectScalarS32True_module
1834
1835 ENTRY %SelectScalarS32True.v4 () -> s32[] {
1836 %constant.2 = pred[] constant(true)
1837 %constant.1 = s32[] constant(-42), sharding={devices=[2,2]1,2,3,4}
1838 %constant = s32[] constant(42)
1839 %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
1840 }
1841
1842 )";
1843 auto result = ParseAndReturnVerifiedModule(original);
1844 TF_EXPECT_OK(result.status());
1845 // Constant instructions have no name. The string will be parsed successfully
1846 // but the constant names will not be exactly the same.
1847 }
1848
1849 TEST_F(HloParserTest, ConfigurationField) {
1850 const string original = R"(HloModule AModule
1851 ENTRY %configuration_test() -> s32[] {
1852 %constant = s32[] constant(42), backend_config="foo bar"
1853 })";
1854 auto result = ParseAndReturnVerifiedModule(original);
1855 TF_ASSERT_OK(result.status());
1856 EXPECT_EQ("foo bar", result.ValueOrDie()
1857 ->entry_computation()
1858 ->root_instruction()
1859 ->raw_backend_config_string());
1860 }
1861
1862 TEST_F(HloParserTest, LiteralDimensionsMismatch_1) {
1863 const string original = R"(HloModule some_2_module
1864
1865 ENTRY %some_2 () -> f32[2] {
1866 ROOT %constant = f32[2]{0} constant({1,{2}})
1867 }
1868
1869 )";
1870 auto result = ParseAndReturnUnverifiedModule(original);
1871 EXPECT_NE(Status::OK(), result.status());
1872 ExpectHasSubstr(result.status().error_message(),
1873 "expects nested array in rank 1, but sees larger");
1874 }
1875
1876 TEST_F(HloParserTest, LiteralDimensionsMismatch_2) {
1877 const string original = R"(HloModule some_2x3_module
1878
1879 ENTRY %some_2x3 () -> f32[2,3] {
1880 ROOT %constant = f32[2,3]{1,0} constant({1, 2, 3, 4, 5, 6})
1881 }
1882
1883 )";
1884 auto result = ParseAndReturnUnverifiedModule(original);
1885 EXPECT_NE(Status::OK(), result.status());
1886 ExpectHasSubstr(result.status().error_message(),
1887 "expects nested array in rank 2, but sees 1");
1888 }
1889
1890 TEST_F(HloParserTest, LiteralDimensionsMismatch_3) {
1891 const string original = R"(HloModule some_2x3x2_module
1892
1893 ENTRY %some_2x3x2 () -> f32[2,3,2] {
1894 ROOT %constant = f32[2,3,2]{2,1,0} constant({{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}})
1895 }
1896
1897 )";
1898 auto result = ParseAndReturnUnverifiedModule(original);
1899 EXPECT_NE(Status::OK(), result.status());
1900 ExpectHasSubstr(result.status().error_message(),
1901 "expects 3 elements in the [0]th element");
1902 }
1903
1904 TEST_F(HloParserTest, ConstantF16Overflow) {
1905 const string original =
1906 R"(HloModule ConstantF16Overflow_module
1907
1908 ENTRY %ConstantF16Overflow.v4 () -> f16[] {
1909 ROOT %constant = f16[] constant(-65520)
1910 }
1911
1912 )";
1913 auto result = ParseAndReturnUnverifiedModule(original);
1914 EXPECT_NE(Status::OK(), result.status());
1915 ExpectHasSubstr(result.status().error_message(),
1916 "is out of range for literal's primitive type F16");
1917 }
1918
1919 TEST_F(HloParserTest, ConstantBf16NoOverflow) {
1920 // 65505 is in range for bf16.
1921 const string original = R"(
1922 HloModule test_module
1923 ENTRY test {
1924 ROOT c = bf16[] constant(-65505)
1925 })";
1926 EXPECT_EQ(Status::OK(), ParseAndReturnVerifiedModule(original).status());
1927 }
1928
1929 TEST_F(HloParserTest, ConstantBf16Overflow) {
1930 // 1e100 is out of range for bf16.
1931 const string original = R"(
1932 HloModule test_module
1933 ENTRY test {
1934 ROOT c = bf16[] constant(1e100)
1935 })";
1936 ExpectHasSubstr(
1937 ParseAndReturnUnverifiedModule(original).status().error_message(),
1938 "out of range");
1939 }
1940
1941 TEST_F(HloParserTest, ConstantUnsignedUnderflow) {
1942 const string original = R"(
1943 HloModule ConstantUnsignedUnderflow_module
1944 ENTRY %ConstantUnsignedUnderflow () -> u64[] {
1945 ROOT %constant = u64[] constant(-1)
1946 })";
1947 auto result = ParseAndReturnUnverifiedModule(original);
1948 EXPECT_NE(Status::OK(), result.status());
1949 ExpectHasSubstr(result.status().error_message(),
1950 "is out of range for literal's primitive type U64");
1951 }
1952
1953 TEST_F(HloParserTest, ConstantUnsignedOverflow) {
1954 const string original = R"(
1955 HloModule ConstantUnsignedOverflow_module
1956 ENTRY %ConstantUnsignedOverflow () -> u32[] {
1957 ROOT %constant = u32[] constant(4294967296)
1958 })";
1959 auto result = ParseAndReturnUnverifiedModule(original);
1960 EXPECT_NE(Status::OK(), result.status());
1961 ExpectHasSubstr(result.status().error_message(),
1962 "is out of range for literal's primitive type U32");
1963 }
1964
1965 TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) {
1966 const string original = R"(
1967 HloModule ConstantUnsignedOverflow_module
1968 ENTRY %ConstantUnsignedOverflow () -> u64[] {
1969 ROOT %constant = u64[] constant(9223372036854775808)
1970 })";
1971 auto result = ParseAndReturnUnverifiedModule(original);
1972 EXPECT_NE(Status::OK(), result.status());
1973 }
1974
1975 TEST_F(HloParserTest, ConstantC64Overflow) {
1976 const string original = R"(
1977 HloModule test_module
1978 ENTRY test () -> c64[] {
1979 ROOT c = c64[] constant((1e100, 0))
1980 })";
1981 auto result = ParseAndReturnUnverifiedModule(original);
1982 EXPECT_NE(Status::OK(), result.status());
1983 }
1984
1985 TEST_F(HloParserTest, ConstantC64Underflow) {
1986 const string original = R"(
1987 HloModule test_module
1988 ENTRY test () -> c64[] {
1989 ROOT c = c64[] constant((0, -1e100))
1990 })";
1991 auto result = ParseAndReturnUnverifiedModule(original);
1992 EXPECT_NE(Status::OK(), result.status());
1993 }
1994
1995 TEST_F(HloParserTest, ConstantF64Overflow) {
1996 const string original = R"(
1997 HloModule test_module
1998 ENTRY test {
1999 ROOT c = f64[] constant(1.8e308)
2000 })";
2001 auto result = ParseAndReturnUnverifiedModule(original);
2002 EXPECT_NE(Status::OK(), result.status());
2003 }
2004
2005 TEST_F(HloParserTest, ConstantF64Underflow) {
2006 const string original = R"(
2007 HloModule test_module
2008 ENTRY test {
2009 ROOT c = f64[] constant(-1.8e308)
2010 })";
2011 auto result = ParseAndReturnUnverifiedModule(original);
2012 EXPECT_NE(Status::OK(), result.status());
2013 }
2014
2015 TEST_F(HloParserTest, ConstantWithExp) {
2016 const string original = R"(HloModule ConstantWithExp_module
2017
2018 ENTRY %ConstantWithExp.v4 () -> f32[] {
2019 %constant.1 = f32[] constant(3e+2)
2020 }
2021
2022 )";
2023 auto result = ParseAndReturnVerifiedModule(original);
2024 TF_EXPECT_OK(result.status());
2025 // The string will be parsed successfully but the output strings are not
2026 // exactly the same, because "3e2" is parsed into value 300 and will be
2027 // printed as "300".
2028 }
2029
2030 TEST_F(HloParserTest, ShortConstant) {
2031 const string original = R"(HloModule ShortConstant_module
2032
2033 ENTRY %ShortConstant.v4 () -> f32[67,89] {
2034 ROOT %constant.1 = f32[67,89]{1,0} constant({...})
2035 }
2036
2037 )";
2038 auto result = ParseAndReturnVerifiedModule(original);
2039 TF_EXPECT_OK(result.status());
2040 EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
2041 }
2042
2043 TEST_F(HloParserTest, AttributesAnyOrder) {
2044 const string original = R"(HloModule any_order_module
2045
2046 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,4,1] {
2047 %input = f32[1,2,1]{2,1,0} parameter(0)
2048 %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
2049 %filter = f32[1,1,1]{2,1,0} parameter(1)
2050 ROOT %convolution = f32[1,4,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=1}
2051 }
2052
2053 )";
2054 TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
2055 }
2056
2057 TEST_F(HloParserTest, InvalidDimLabels) {
2058 string prefix = R"(HloModule invalid_dim_labels_module
2059
2060 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
2061 %input = f32[1,2,1]{2,1,0} parameter(0)
2062 %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
2063 %filter = f32[1,1,1]{2,1,0} parameter(1)
2064 ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )";
2065 string suffix = R"(
2066 }
2067
2068 )";
2069
2070 ExpectHasSubstr(ParseAndReturnUnverifiedModule(
2071 absl::StrCat(prefix, ",dim_labels=00_01_10", suffix))
2072 .status()
2073 .error_message(),
2074 "expects dim labels pattern");
2075
2076 ExpectHasSubstr(ParseAndReturnUnverifiedModule(
2077 absl::StrCat(prefix, ",dim_labels=010_1100->010", suffix))
2078 .status()
2079 .error_message(),
2080 "must have the same rank");
2081 }
2082
2083 TEST_F(HloParserTest, UnexpectedAttribute) {
2084 const string original = R"(HloModule unexpected_attr_module
2085
2086 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
2087 %token0 = token[] after-all()
2088 %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
2089 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
2090 ROOT %constant = f32[] constant(2.1)
2091 %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, calls=%recv
2092 %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
2093 }
2094
2095 )";
2096 ExpectHasSubstr(
2097 ParseAndReturnUnverifiedModule(original).status().error_message(),
2098 "unexpected attribute \"calls\"");
2099 }
2100
2101 TEST_F(HloParserTest, MissingAttribute) {
2102 const string original = R"(HloModule missing_attr_module
2103
2104 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
2105 %token0 = token[] after-all()
2106 %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
2107 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
2108 ROOT %constant = f32[] constant(-2.1)
2109 %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0)
2110 %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
2111 }
2112
2113 )";
2114 ExpectHasSubstr(
2115 ParseAndReturnUnverifiedModule(original).status().error_message(),
2116 "attribute channel_id is expected but not seen");
2117 }
2118
2119 TEST_F(HloParserTest, PredecessorUndefined) {
2120 const string original = R"(HloModule pre_not_found_module
2121
2122 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
2123 %token0 = token[] after-all()
2124 %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
2125 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
2126 ROOT %constant = f32[] constant(2.1)
2127 %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, control-predecessors={%done}
2128 %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
2129 }
2130
2131 )";
2132 ExpectHasSubstr(
2133 ParseAndReturnUnverifiedModule(original).status().error_message(),
2134 "'done' is not defined");
2135 }
2136
2137 TEST_F(HloParserTest, SliceAllowOmitStride1) {
2138 const string original = R"(HloModule slice_module
2139
2140 ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
2141 %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
2142 ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3], [0:3], [0:4:2], [0:4]}
2143 }
2144
2145 )";
2146 TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
2147 }
2148
2149 TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) {
2150 const string original = R"(HloModule window_pad_module
2151
2152 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
2153 %input = f32[1,2,1]{2,1,0} parameter(0)
2154 %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
2155 %filter = f32[1,1,1]{2,1,0} parameter(1)
2156 ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), dim_labels=b0f_0io->b0f, window={pad=1_1_0 size=1}
2157 }
2158
2159 )";
2160 ExpectHasSubstr(
2161 ParseAndReturnUnverifiedModule(original).status().error_message(),
2162 "expects padding_low and padding_high separated by '_'");
2163 }
2164
2165 TEST_F(HloParserTest, CommaBetweenSubAttributes) {
2166 const string original = R"(HloModule test_comma_module
2167
2168 ENTRY %test_comma.v4 () -> f32[] {
2169 ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"}
2170 }
2171
2172 )";
2173 TF_EXPECT_OK(ParseAndReturnVerifiedModule(original).status());
2174 }
2175
2176 TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) {
2177 const string original = R"(HloModule custom_call:
2178
2179 ENTRY %CustomCall () -> f32[1] {
2180 %constant = f32[1]{0} constant({12345})
2181 ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar"
2182 })";
2183 ExpectHasSubstr(
2184 ParseAndReturnUnverifiedModule(original).status().error_message(),
2185 "Shape of computation CustomCall, f32[1], is not compatible "
2186 "with that of its root instruction foo, f32[1,2,3]");
2187 }
2188
2189 TEST_F(HloParserTest, EntryComputationWithLayout) {
2190 const string original = R"(HloModule layout:
2191 add_F32.v3 {
2192 lhs = f32[] parameter(0)
2193 rhs = f32[] parameter(1)
2194 ROOT add = f32[] add(lhs, rhs)
2195 }
2196
2197 ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
2198 input = f32[8,16,256]{0,1,2} parameter(0)
2199 constant = f32[] constant(0)
2200 ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
2201 })";
2202
2203 auto module = ParseAndReturnVerifiedModule(original);
2204 TF_ASSERT_OK(module.status());
2205 auto program_layout = module.ValueOrDie()->entry_computation_layout();
2206 ASSERT_EQ(program_layout.parameter_count(), 1);
2207 auto param_layout = program_layout.parameter_layout(0).layout();
2208 auto result_layout = program_layout.result_layout().layout();
2209 EXPECT_TRUE(
2210 LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2}), param_layout))
2211 << "actual layout of parameter(0) is "
2212 << LayoutUtil::HumanString(param_layout);
2213 EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}), result_layout))
2214 << "actual layout of result is "
2215 << LayoutUtil::HumanString(result_layout);
2216 }
2217
2218 TEST_F(HloParserTest, NoEntry) {
2219 const string original = R"(HloModule no_entry:
2220 c1 {
2221 const1 = f32[1]{0} constant({12345})
2222 }
2223 c2 {
2224 const2 = f32[1]{0} constant({67890})
2225 })";
2226 auto module = ParseAndReturnVerifiedModule(original);
2227 TF_ASSERT_OK(module.status());
2228 EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2");
2229 }
2230
2231 TEST_F(HloParserTest, NoRoot) {
2232 const string original = R"(HloModule no_root:
2233 ENTRY consts {
2234 first = f32[1]{0} constant({12345})
2235 last = f32[1]{0} constant({67890})
2236 })";
2237 auto module = ParseAndReturnVerifiedModule(original);
2238 TF_ASSERT_OK(module.status());
2239 EXPECT_EQ(
2240 module.ValueOrDie()->entry_computation()->root_instruction()->name(),
2241 "last");
2242 }
2243
2244 TEST_F(HloParserTest, Comments) {
2245 const string original = R"(/* module description. */
2246 HloModule comments:
2247
2248 ENTRY /*comment*/ c1 {
2249 /* blah */
2250 ROOT const1 = /*foo*/f32[1]{0} constant({12345 /*bar*/})
2251 /* comment */
2252 }
2253
2254 /* something else */
2255
2256 )";
2257 auto module = ParseAndReturnVerifiedModule(original);
2258 TF_ASSERT_OK(module.status());
2259 }
2260
2261 TEST_F(HloParserTest, MultilineComments) {
2262 const string original = R"(HloModule multiline_comment:
2263 ENTRY c1 {
2264 /*
2265 ROOT foo = f32[1]{0} constant({12345})
2266 */
2267 ROOT const1 = f32[1]{0} constant({12345})
2268 /*
2269 a
2270 b
2271 c
2272 d
2273
2274 */
2275 })";
2276 auto module = ParseAndReturnVerifiedModule(original);
2277 TF_ASSERT_OK(module.status());
2278 }
2279
2280 TEST_F(HloParserTest, UnterminatedComment) {
2281 const string original = R"(HloModule unterminated_comment:
2282 ENTRY c1 {
2283 /* unterminated
2284 ROOT const1 = f32[1]{0} constant({12345})
2285 })";
2286 // Verify that the error message points to the beginning of the unterminated
2287 // comment.
2288 ExpectHasSubstr(
2289 ParseAndReturnUnverifiedModule(original).status().error_message(),
2290 "/* unterminated\n^");
2291 }
2292
2293 TEST_F(HloParserTest, SlashSlashComments) {
2294 const string original = R"(HloModule slash_slash_comment:
2295 // Garbage
2296 ENTRY c1 {
2297 // Foo bar
2298 ROOT const1 = f32[1]{0} constant({12345}) // Something else
2299 })";
2300 auto module = ParseAndReturnVerifiedModule(original);
2301 TF_ASSERT_OK(module.status());
2302 }
2303
2304 TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) {
2305 const string original =
2306 "HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo "
2307 "bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}";
2308 auto module = ParseAndReturnVerifiedModule(original);
2309 TF_ASSERT_OK(module.status());
2310 }
2311
2312 TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) {
2313 const string original =
2314 "HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo "
2315 "bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}";
2316 auto module = ParseAndReturnVerifiedModule(original);
2317 TF_ASSERT_OK(module.status());
2318 }
2319
2320 TEST_F(HloParserTest, MultipleEntries) {
2321 const string original = R"(HloModule multiple_entries:
2322 ENTRY c1 {
2323 const1 = f32[1]{0} constant({12345})
2324 }
2325 ENTRY c2 {
2326 const2 = f32[1]{0} constant({67890})
2327 })";
2328 ExpectHasSubstr(
2329 ParseAndReturnUnverifiedModule(original).status().error_message(),
2330 "expects only one ENTRY");
2331 }
2332
2333 TEST_F(HloParserTest, MultipleRoots) {
2334 const string original = R"(HloModule multiple_roots:
2335 ENTRY consts {
2336 ROOT const1 = f32[1]{0} constant({12345})
2337 ROOT const2 = f32[1]{0} constant({12345})
2338 })";
2339 ExpectHasSubstr(
2340 ParseAndReturnUnverifiedModule(original).status().error_message(),
2341 "one computation should have only one ROOT");
2342 }
2343
2344 TEST_F(HloParserTest, ComputationExists) {
2345 const string original = R"(HloModule comp_exists
2346 comp {
2347 const1 = f32[1]{0} constant({12345})
2348 }
2349 comp {
2350 const2 = f32[1]{0} constant({67890})
2351 })";
2352 ExpectHasSubstr(
2353 ParseAndReturnUnverifiedModule(original).status().error_message(),
2354 R"(was parsing 2:1: error: computation previously defined here
2355 comp {
2356 ^)");
2357 }
2358
2359 TEST_F(HloParserTest, CrossComputationLookup) {
2360 const string original = R"(HloModule cross_computation_lookup:
2361 tcalla (a: (s32[], s32[])) -> (s32[], s32[]) {
2362 ROOT aparam = (s32[], s32[]) parameter(0)
2363 }
2364
2365 tcallb (b: (s32[], s32[])) -> s32[] {
2366 rparam = (s32[], s32[]) parameter(0)
2367 ROOT gte0 = s32[] get-tuple-element(aparam), index=0
2368 }
2369
2370 ENTRY entry {
2371 param = (s32[], s32[]) parameter(0)
2372 call0 = (s32[], s32[]) call(param), to_apply=tcalla
2373 ROOT call1 = s32[] call(param), to_apply=tcallb
2374 })";
2375 ExpectHasSubstr(
2376 ParseAndReturnUnverifiedModule(original).status().error_message(),
2377 "was parsing 8:39: error: instruction does not exist: aparam");
2378 }
2379
2380 TEST_F(HloParserTest, SameNameDiffComputations) {
2381 const string original = R"(HloModule same_names:
2382 add {
2383 p0 = f32[] parameter(0)
2384 p1 = f32[] parameter(1)
2385 ROOT result = f32[] add(p0, p1)
2386 }
2387
2388 ENTRY ReduceR3ToR2 {
2389 p0 = f32[8,16,256]{2,1,0} parameter(0)
2390 p1 = f32[] constant(0)
2391 ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add
2392 }
2393 )";
2394 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(original));
2395 ASSERT_NE(module->entry_computation(), nullptr);
2396 EXPECT_THAT(module->entry_computation()->root_instruction(),
2397 GmockMatch(m::Reduce()));
2398 }
2399
2400 TEST_F(HloParserTest, ParseSharding) {
2401 const string original = "{maximal device=42}";
2402 TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
2403 EXPECT_EQ(sharding.ToString(), original);
2404 }
2405
2406 TEST_F(HloParserTest, ParseFrontendAttributes) {
2407 const string original = "{attr_a=test_a,attr_b=b}";
2408 TF_ASSERT_OK_AND_ASSIGN(FrontendAttributes frontend_attributes,
2409 ParseFrontendAttributes(original));
2410 EXPECT_EQ(FrontendAttributesToString(frontend_attributes), original);
2411 }
2412
2413 TEST_F(HloParserTest, ParseWindow) {
2414 Window original = window_util::MakeWindow({1, 2, 3});
2415 TF_ASSERT_OK_AND_ASSIGN(Window parsed,
2416 ParseWindow(window_util::ToString(original)))
2417 EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed));
2418 }
2419
2420 TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
2421 const string original = "b0f_0io->b0f";
2422 TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums,
2423 ParseConvolutionDimensionNumbers(original));
2424 EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
2425 }
2426
2427 TEST_F(HloParserTest, ParseReplicaGroups) {
2428 const string original = "{{0,1},{2,3}}";
2429 TF_ASSERT_OK_AND_ASSIGN(std::vector<ReplicaGroup> replica_groups,
2430 ParseReplicaGroupsOnly(original));
2431 EXPECT_EQ(original, ReplicaGroupsToString(replica_groups));
2432 }
2433
2434 TEST_F(HloParserTest, ParsePaddingConfigNoInteriorPadding) {
2435 const string original = "0_1x2_3";
2436 TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
2437 EXPECT_EQ(original, PaddingConfigToString(dnums));
2438 }
2439
2440 TEST_F(HloParserTest, ParsePaddingConfigInteriorPadding) {
2441 const string original = "0_1_0x2_3_4";
2442 TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
2443 EXPECT_EQ(original, PaddingConfigToString(dnums));
2444 }
2445
2446 TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) {
2447 TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig("0_1x2_3_4"));
2448 // The extra "_0" gets added to the canonical string because the other dim has
2449 // interior padding.
2450 EXPECT_EQ("0_1_0x2_3_4", PaddingConfigToString(dnums));
2451 }
2452
2453 TEST_F(HloParserTest, NontupleInfeed) {
2454 const string original = R"(HloModule nontuple_infeed:
2455 ENTRY nontuple_infeed {
2456 token0 = token[] after-all()
2457 ROOT infeed = pred[] infeed(token0)
2458 })";
2459 ExpectHasSubstr(
2460 ParseAndReturnUnverifiedModule(original).status().error_message(),
2461 "infeed must have a non-empty tuple shape");
2462 }
2463
2464 TEST(HloParserSingleOpTest, SingleOp) {
2465 const string text =
2466 "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, "
2467 "f32[2,4]{1,0} %x)";
2468 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
2469 const HloComputation* computation = module->entry_computation();
2470 ASSERT_NE(computation, nullptr);
2471 EXPECT_THAT(computation->root_instruction(),
2472 GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
2473 }
2474
2475 TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) {
2476 const string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)";
2477 StatusOr<std::unique_ptr<HloModule>> module =
2478 ParseAndReturnUnverifiedModule(text);
2479 ASSERT_TRUE(!module.status().ok());
2480 LOG(INFO) << "Status: " << module.status();
2481 EXPECT_THAT(module.status().ToString(),
2482 ::testing::HasSubstr("expects '=' in instruction"));
2483 }
2484
2485 TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) {
2486 const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)";
2487 StatusOr<std::unique_ptr<HloModule>> module =
2488 ParseAndReturnUnverifiedModule(text);
2489 ASSERT_TRUE(!module.status().ok());
2490 LOG(INFO) << "Status: " << module.status();
2491 EXPECT_THAT(module.status().ToString(),
2492 ::testing::HasSubstr("Operand had no shape in HLO text"));
2493 }
2494
2495 TEST(HloParserSingleOpTest, SingleOpNoNames) {
2496 const string text =
2497 "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
2498 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
2499 const HloComputation* computation = module->entry_computation();
2500 ASSERT_NE(computation, nullptr);
2501 EXPECT_THAT(computation->root_instruction(),
2502 GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
2503 }
2504
2505 TEST(HloParserSingleOpTest, CanonicalOp) {
2506 const string text = "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
2507 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
2508 const HloComputation* computation = module->entry_computation();
2509 ASSERT_NE(computation, nullptr);
2510 EXPECT_THAT(computation->root_instruction(),
2511 GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
2512 EXPECT_EQ(
2513 computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
2514 text);
2515 }
2516
2517 TEST(HloParserSingleOpTest, CanonicalOpWithNested) {
2518 const string text =
2519 R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
2520 {
2521 tmp_0 = f32[5,10]{1,0} parameter(0)
2522 tmp_1 = f32[20,10]{1,0} parameter(1)
2523 ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
2524 {
2525 tmp_0 = f32[5,10]{1,0} parameter(0)
2526 tmp_1 = f32[20,10]{1,0} parameter(1)
2527 tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
2528 ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
2529 }
2530 }, body=
2531 {
2532 tmp_0 = f32[5,10]{1,0} parameter(0)
2533 tmp_1 = f32[20,10]{1,0} parameter(1)
2534 ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
2535 {
2536 tmp_0 = f32[5,10]{1,0} parameter(0)
2537 tmp_1 = f32[20,10]{1,0} parameter(1)
2538 tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
2539 ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
2540 }
2541 })";
2542
2543 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
2544 const HloComputation* computation = module->entry_computation();
2545 ASSERT_NE(computation, nullptr);
2546 EXPECT_EQ(
2547 computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
2548 text);
2549 }
2550
2551 TEST(HloParserSingleOpTest, CanonicalOpIndexedConditionalInlinedBranches) {
2552 const string text =
2553 R"(f32[5,10]{1,0} conditional(s32[], f32[5,10]{1,0}, f32[5,10]{1,0}, f32[5,10]{1,0}), branch_computations={
2554 {
2555 tmp_0 = f32[5,10]{1,0} parameter(0)
2556 ROOT tmp_1 = f32[5,10]{1,0} ceil(f32[5,10]{1,0} tmp_0)
2557 },
2558 {
2559 tmp_0 = f32[5,10]{1,0} parameter(0)
2560 ROOT tmp_1 = f32[5,10]{1,0} floor(f32[5,10]{1,0} tmp_0)
2561 },
2562 {
2563 tmp_0 = f32[5,10]{1,0} parameter(0)
2564 ROOT tmp_1 = f32[5,10]{1,0} copy(f32[5,10]{1,0} tmp_0)
2565 }
2566 })";
2567
2568 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
2569 const HloComputation* computation = module->entry_computation();
2570 ASSERT_NE(computation, nullptr);
2571 EXPECT_EQ(
2572 computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
2573 text);
2574 }
2575
2576 TEST(HloParserSingleOpTest, SingleOpWithNested) {
2577 const string text =
2578 R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls=
2579 {
2580 %param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
2581 %param_1 = f32[2]{0} parameter(1)
2582 %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %param_1), dimensions={1}
2583 ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
2584 })";
2585
2586 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
2587 const HloComputation* computation = module->entry_computation();
2588 ASSERT_NE(computation, nullptr);
2589 EXPECT_THAT(computation->root_instruction(),
2590 GmockMatch(m::Op()
2591 .WithOpcode(HloOpcode::kFusion)
2592 .WithNumOperands(2)
2593 .WithOperand(0, m::Parameter(0))
2594 .WithOperand(1, m::Parameter(1))));
2595 }
2596
2597 TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) {
2598 const string text =
2599 R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
2600 {
2601 result = f32[] add(f32[] x, f32[] y)
2602 })";
2603 auto status = ParseAndReturnUnverifiedModule(text).status();
2604 ASSERT_FALSE(status.ok());
2605 EXPECT_THAT(status.error_message(),
2606 ::testing::HasSubstr("does not exist: x"));
2607 }
2608
2609 TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) {
2610 const string text =
2611 R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
2612 {
2613 f32[] add(f32[] x, f32[] y)
2614 })";
2615 auto status = ParseAndReturnUnverifiedModule(text).status();
2616 ASSERT_FALSE(status.ok());
2617 EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
2618 }
2619
2620 TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) {
2621 const string text =
2622 R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
2623 {
2624 result = f32[] add(f32[], f32[])
2625 })";
2626 auto status = ParseAndReturnUnverifiedModule(text).status();
2627 ASSERT_FALSE(status.ok());
2628 EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
2629 }
2630
2631 TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
2632 const string text =
2633 R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)";
2634 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnUnverifiedModule(text));
2635 const HloComputation* computation = module->entry_computation();
2636 ASSERT_NE(computation, nullptr);
2637 EXPECT_THAT(computation->root_instruction(),
2638 GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
2639 auto* convolution =
2640 Cast<HloConvolutionInstruction>(computation->root_instruction());
2641 EXPECT_EQ(convolution->feature_group_count(), 1);
2642 }
2643
2644 TEST(HloParserSingleOpTest, MultipleOpsProducesError) {
2645 const string text = R"(
2646 param = f32[2,5,1,3] parameter(0)
2647 transpose = f32[1,5,2,3] transpose(param), dimensions={2,1,0,3}
2648 )";
2649 auto status = ParseAndReturnUnverifiedModule(text).status();
2650 ASSERT_FALSE(status.ok());
2651 EXPECT_THAT(status.error_message(), ::testing::HasSubstr("Expected eof"));
2652 }
2653
2654 TEST_F(HloParserTest, IsScheduledIsFalse) {
2655 const string text = R"(
2656 HloModule axpy_module, is_scheduled=false
2657
2658 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2659 %alpha = f32[] parameter(0)
2660 %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2661 %x = f32[2,4]{1,0} parameter(1)
2662 %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2663 %y = f32[2,4]{1,0} parameter(2)
2664 ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2665 }
2666 )";
2667 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
2668 ASSERT_FALSE(module->has_schedule());
2669 }
2670
2671 TEST_F(HloParserTest, IsScheduledNotPresent) {
2672 const string text = R"(
2673 HloModule axpy_module
2674
2675 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2676 %alpha = f32[] parameter(0)
2677 %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2678 %x = f32[2,4]{1,0} parameter(1)
2679 %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2680 %y = f32[2,4]{1,0} parameter(2)
2681 ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2682 }
2683 )";
2684 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
2685 ASSERT_FALSE(module->has_schedule());
2686 }
2687
2688 TEST_F(HloParserTest, IsScheduledIsTrue) {
2689 const string text = R"(
2690 HloModule axpy_module, is_scheduled=true
2691
2692 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2693 %alpha = f32[] parameter(0)
2694 %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2695 %x = f32[2,4]{1,0} parameter(1)
2696 %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2697 %y = f32[2,4]{1,0} parameter(2)
2698 ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2699 }
2700 )";
2701 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
2702 ASSERT_TRUE(module->has_schedule());
2703 TF_ASSERT_OK(module->schedule().Verify());
2704 EXPECT_EQ(module->schedule().sequences().size(), 1);
2705 ASSERT_TRUE(
2706 module->schedule().is_computation_scheduled(module->entry_computation()));
2707 EXPECT_THAT(
2708 module->schedule().sequence(module->entry_computation()).instructions(),
2709 ::testing::ElementsAre(
2710 GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
2711 GmockMatch(m::Parameter()), GmockMatch(m::Multiply()),
2712 GmockMatch(m::Parameter()), GmockMatch(m::Add())));
2713 }
2714
2715 TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) {
2716 // As above but in with a different schedule order.
2717 const string text = R"(
2718 HloModule axpy_module, is_scheduled=true
2719
2720 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2721 %alpha = f32[] parameter(0)
2722 %x = f32[2,4]{1,0} parameter(1)
2723 %y = f32[2,4]{1,0} parameter(2)
2724 %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2725 %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2726 ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2727 }
2728 )";
2729 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
2730 ASSERT_TRUE(module->has_schedule());
2731 TF_ASSERT_OK(module->schedule().Verify());
2732 EXPECT_EQ(module->schedule().sequences().size(), 1);
2733 ASSERT_TRUE(
2734 module->schedule().is_computation_scheduled(module->entry_computation()));
2735 EXPECT_THAT(
2736 module->schedule().sequence(module->entry_computation()).instructions(),
2737 ::testing::ElementsAre(
2738 GmockMatch(m::Parameter()), GmockMatch(m::Parameter()),
2739 GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
2740 GmockMatch(m::Multiply()), GmockMatch(m::Add())));
2741 }
2742
2743 TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) {
2744 const string original = R"(HloModule CustomCallWrongNumberofOperandConstraints
2745
2746 ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
2747 %p0 = f32[42,2,3]{0,1,2} parameter(0)
2748 %p1 = f32[123,4]{0,1} parameter(1)
2749 ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}}
2750 }
2751
2752 )";
2753 ExpectHasSubstr(
2754 ParseAndReturnUnverifiedModule(original).status().error_message(),
2755 "Expected 2 operand layout constraints, 1 given");
2756 }
2757
2758 TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) {
2759 const string original = R"(HloModule CustomCallIncompatibleOperandConstraints
2760
2761 ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
2762 %p0 = f32[42,2,3]{0,1,2} parameter(0)
2763 %p1 = f32[123,4]{0,1} parameter(1)
2764 ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}}
2765 }
2766
2767 )";
2768 ExpectHasSubstr(
2769 ParseAndReturnUnverifiedModule(original).status().error_message(),
2770 "operand 1 is not compatible with operand shape");
2771 }
2772
2773 TEST_F(HloParserTest, AllowShapeWhitespace) {
2774 const string text = R"(
2775 HloModule module
2776
2777 ENTRY entry {
2778 ROOT root = f32[ 1, 2,3, 4, 5]{0, 1, 2,3, 4 } parameter(0)
2779 }
2780 )";
2781 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(text));
2782 }
2783
2784 TEST_F(HloParserTest, ShapeMismatchInOperand) {
2785 const string text = R"(
2786 HloModule foobar
2787
2788 ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] {
2789 %p = f32[2,2] parameter(0)
2790 %constant.1 = f32[2,2] constant({{1, 2}, {3, 4}})
2791 ROOT %add.1 = f32[2,2] add(f32[2,2] %p, f32[2,5] %constant.1)
2792 }
2793 )";
2794
2795 ExpectHasSubstr(ParseAndReturnUnverifiedModule(text).status().error_message(),
2796 "The declared operand shape f32[2,5]{1,0} is not compatible"
2797 " with the shape of the operand instruction f32[2,2]{1,0}.");
2798 }
2799
2800 TEST_F(HloParserTest, ParseShapeStringR2F32) {
2801 string shape_string = "f32[123,456]";
2802 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2803 Shape expected = ShapeUtil::MakeShape(F32, {123, 456});
2804 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2805 << "expected: " << ShapeUtil::HumanString(expected)
2806 << "actual: " << ShapeUtil::HumanString(actual);
2807 }
2808
2809 TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) {
2810 string shape_string = "(f32[1572864],s8[5120,1024])";
2811 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2812 Shape expected =
2813 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}),
2814 ShapeUtil::MakeShape(S8, {5120, 1024})});
2815 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2816 << "expected: " << ShapeUtil::HumanString(expected)
2817 << "actual: " << ShapeUtil::HumanString(actual);
2818 }
2819
2820 TEST_F(HloParserTest, ParseShapeStringNestedTuple) {
2821 string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])";
2822 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2823 Shape expected = ShapeUtil::MakeTupleShape({
2824 ShapeUtil::MakeShape(F32, {1}),
2825 ShapeUtil::MakeTupleShape(
2826 {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}),
2827 ShapeUtil::MakeOpaqueShape(),
2828 ShapeUtil::MakeShape(F32, {3}),
2829 });
2830 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2831 << "expected: " << ShapeUtil::HumanString(expected)
2832 << "actual: " << ShapeUtil::HumanString(actual);
2833 }
2834
2835 TEST_F(HloParserTest, ParseShapeStringWithLayout) {
2836 string shape_string = "f32[123,456]{0,1}";
2837 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2838 Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1});
2839 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2840 << "expected: " << ShapeUtil::HumanString(expected)
2841 << "actual: " << ShapeUtil::HumanString(actual);
2842 }
2843
2844 TEST_F(HloParserTest, ParseShapeStringWithTilingLayout) {
2845 // One tile.
2846 string shape_string = "f32[123,456]{0,1:T(2,128)}";
2847 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2848 Shape expected =
2849 ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}, {Tile({2, 128})});
2850 EXPECT_EQ(expected, actual)
2851 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2852 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
2853
2854 // Tile with negative dimension size for combining dimensions.
2855 shape_string = "f32[123,456,789]{0,1,2:T(2, * , 128)}";
2856 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2857 expected =
2858 ShapeUtil::MakeShapeWithLayout(F32, {123, 456, 789}, {0, 1, 2},
2859 {Tile({2, Tile::kCombineDimension, 128})});
2860 EXPECT_EQ(expected, actual)
2861 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2862 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
2863
2864 // Two tiles.
2865 shape_string = "bf16[123,456,789]{2,1,0:T(2,*,128)(2,1)}";
2866 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2867 expected = ShapeUtil::MakeShapeWithLayout(
2868 BF16, {123, 456, 789}, {2, 1, 0},
2869 {Tile({2, Tile::kCombineDimension, 128}), Tile({2, 1})});
2870 EXPECT_EQ(expected, actual)
2871 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2872 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
2873
2874 // Tile with element size in bits.
2875 shape_string = "pred[123,456]{1,0:T(2,128)E(1)}";
2876 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2877 expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0},
2878 {Tile({2, 128})}, 1);
2879 EXPECT_EQ(expected, actual)
2880 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2881 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
2882
2883 // Element size in bits without tile.
2884 shape_string = "pred[123,456]{1,0:E(1)}";
2885 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2886 expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, 1);
2887 EXPECT_EQ(expected, actual)
2888 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2889 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
2890
2891 // Wrong minor_to_major.
2892 shape_string = "f32[123,456,789]{1:T(2, * , 128)}";
2893 auto result = ParseShape(shape_string);
2894 ExpectHasSubstr(result.status().error_message(),
2895 "Dimensions size is 3, but minor to major size is 1.");
2896 }
2897
2898 TEST_F(HloParserTest, ParseShapeStringWithMemorySpaceLayout) {
2899 // Tile, element size, and memory space.
2900 string shape_string = "pred[123,456]{1,0:T(2,128)E(1)S(3)}";
2901 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2902 Shape expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0},
2903 {Tile({2, 128})}, 1, 3);
2904 EXPECT_EQ(expected, actual)
2905 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2906 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
2907
2908 // Element size and memory space.
2909 shape_string = "pred[123,456]{1,0:E(1)S(3)}";
2910 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2911 expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, 1, 3);
2912 EXPECT_EQ(expected, actual)
2913 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2914 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
2915
2916 // Memory space only.
2917 shape_string = "pred[123,456]{1,0:S(3)}";
2918 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2919 expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, 0, 3);
2920 EXPECT_EQ(expected, actual)
2921 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2922 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
2923 }
2924
2925 TEST_F(HloParserTest, ParseOpaqueType) {
2926 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("opaque[]"));
2927 Shape expected = ShapeUtil::MakeOpaqueShape();
2928 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2929 << "expected: " << ShapeUtil::HumanString(expected)
2930 << "actual: " << ShapeUtil::HumanString(actual);
2931 }
2932
2933 TEST_F(HloParserTest, ParseTokenType) {
2934 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("token[]"));
2935 Shape expected = ShapeUtil::MakeTokenShape();
2936 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2937 << "expected: " << ShapeUtil::HumanString(expected)
2938 << "actual: " << ShapeUtil::HumanString(actual);
2939 }
2940
2941 TEST_F(HloParserTest, ParseInvalidShapeString) {
2942 string shape_strings[] = {"f32[123,456]foobar{0,1}", "f32[123,456]{foo}",
2943 "f32[123,456]dense{foo}"};
2944 for (const string& shape_string : shape_strings) {
2945 StatusOr<Shape> result = ParseShape(shape_string);
2946 ASSERT_FALSE(result.ok()) << "shape: " << shape_string;
2947 }
2948 }
2949
2950 TEST_F(HloParserTest, ParseDynamicArray) {
2951 string shape_string = "f32[123,<=456]";
2952 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2953 Shape expected = ShapeUtil::MakeShape(F32, {123, 456}, {false, true});
2954 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2955 << "expected: " << ShapeUtil::HumanString(expected)
2956 << "actual: " << ShapeUtil::HumanString(actual);
2957 }
2958
2959 TEST_F(HloParserTest, ParseDynamicTuple) {
2960 string shape_string = "(f32[42], u32[<=123,<=456])";
2961 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2962 Shape expected = ShapeUtil::MakeTupleShape(
2963 {ShapeUtil::MakeShape(F32, {42}),
2964 ShapeUtil::MakeShape(U32, {123, 456}, {true, true})});
2965 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2966 << "expected: " << ShapeUtil::HumanString(expected)
2967 << "actual: " << ShapeUtil::HumanString(actual);
2968 }
2969
2970 TEST_F(HloParserTest, NegativeParameterNumber) {
2971 const string hlo_string = "par0 = f32[3,5] parameter(-1)";
2972 auto result = ParseAndReturnUnverifiedModule(hlo_string);
2973 ASSERT_FALSE(result.status().ok());
2974 EXPECT_THAT(result.status().error_message(),
2975 ::testing::HasSubstr("parameter number must be >= 0"));
2976 }
2977
2978 TEST_F(HloParserTest, WrongNumberOfParameterLeafBuffersInReplication) {
2979 const string hlo_string =
2980 "par0 = (f32[3,5], f32[]) parameter(0), "
2981 "parameter_replication={true,false,true}";
2982 auto result = ParseAndReturnUnverifiedModule(hlo_string);
2983 ASSERT_FALSE(result.status().ok());
2984 EXPECT_THAT(result.status().error_message(),
2985 ::testing::HasSubstr("parameter has 2 leaf buffers, but "
2986 "parameter_replication has 3 elements"));
2987 }
2988
2989 TEST_F(HloParserTest, CheckIndexedConditionalDimension) {
2990 const char* const hlo_string = R"(
2991 HloModule Module
2992
2993 branch0 {
2994 tparam = f32[4] parameter(0)
2995 ROOT tgte1 = f32[4] ceil(tparam)
2996 }
2997
2998 branch1 {
2999 fparam = f32[4] parameter(0)
3000 ROOT fgte1 = f32[4] floor(fparam)
3001 }
3002
3003 ENTRY entry {
3004 p0 = f32[4] parameter(0)
3005 b0 = s32[2] parameter(1)
3006 ROOT conditional = f32[4] conditional(b0, p0, p0),
3007 branch_computations={branch0, branch1}
3008 }
3009 )";
3010 auto result = ParseAndReturnUnverifiedModule(hlo_string);
3011 EXPECT_NE(Status::OK(), result.status());
3012 EXPECT_THAT(result.status().error_message(),
3013 ::testing::HasSubstr("The first operand must be a scalar"));
3014 }
3015
3016 TEST_F(HloParserTest, CheckIndexedConditionalElementType) {
3017 const char* const hlo_string = R"(
3018 HloModule Module
3019
3020 branch0 {
3021 tparam = f32[4] parameter(0)
3022 ROOT tgte1 = f32[4] ceil(tparam)
3023 }
3024
3025 branch1 {
3026 fparam = f32[4] parameter(0)
3027 ROOT fgte1 = f32[4] floor(fparam)
3028 }
3029
3030 ENTRY entry {
3031 p0 = f32[4] parameter(0)
3032 b0 = f32[] parameter(1)
3033 ROOT conditional = f32[4] conditional(b0, p0, p0),
3034 branch_computations={branch0, branch1}
3035 }
3036 )";
3037 auto result = ParseAndReturnUnverifiedModule(hlo_string);
3038 EXPECT_NE(Status::OK(), result.status());
3039 EXPECT_THAT(result.status().error_message(),
3040 ::testing::HasSubstr(
3041 "The first operand must be a scalar of PRED or S32"));
3042 }
3043
3044 TEST_F(HloParserTest,
3045 CheckPredicatedConditionalRequiresTrueAndFalseComputation) {
3046 const char* const hlo_string = R"(
3047 HloModule Module
3048
3049 branch0 {
3050 tparam = f32[4] parameter(0)
3051 ROOT tgte1 = f32[4] ceil(tparam)
3052 }
3053
3054 branch1 {
3055 fparam = f32[4] parameter(0)
3056 ROOT fgte1 = f32[4] floor(fparam)
3057 }
3058
3059 ENTRY entry {
3060 p0 = f32[4] parameter(0)
3061 b0 = pred[] parameter(1)
3062 ROOT conditional = f32[4] conditional(b0, p0, p0),
3063 branch_computations={branch0, branch1}
3064 }
3065 )";
3066 auto result = ParseAndReturnUnverifiedModule(hlo_string);
3067 EXPECT_NE(Status::OK(), result.status());
3068 EXPECT_THAT(
3069 result.status().error_message(),
3070 ::testing::HasSubstr("unexpected attribute \"branch_computations\""));
3071 }
3072
3073 } // namespace
3074 } // namespace xla
3075