1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "gmock/gmock.h"
16 #include "src/ast/stage_decoration.h"
17 #include "src/ast/struct_block_decoration.h"
18 #include "src/ast/variable_decl_statement.h"
19 #include "src/ast/workgroup_decoration.h"
20 #include "src/writer/hlsl/test_helper.h"
21
22 using ::testing::HasSubstr;
23
24 namespace tint {
25 namespace writer {
26 namespace hlsl {
27 namespace {
28
// Fixture for the HLSL function-emission tests in this file. It is the shared
// HLSL writer TestHelper; the tests below use its Build() and
// SanitizeAndBuild() to obtain a GeneratorImpl for the program they construct.
using HlslGeneratorImplTest_Function = TestHelper;
30
// A plain (non-entry-point) void function with an explicit return. The
// generator's indent is bumped once before generating, so every emitted line
// is expected to carry one extra indentation level.
TEST_F(HlslGeneratorImplTest_Function, Emit_Function) {
  Func("my_func", {}, ty.void_(), {Return()});

  GeneratorImpl& generator = Build();
  generator.increment_indent();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected = R"(  void my_func() {
    return;
  }
)";
  EXPECT_EQ(generator.result(), expected);
}
47
// A function named "GeometryShader" collides with an HLSL name, so after
// sanitizing it is expected to be emitted under the generated name
// "tint_symbol". Only the function itself is checked (substring match).
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Name_Collision) {
  Func("GeometryShader", {}, ty.void_(), {Return()});

  GeneratorImpl& generator = SanitizeAndBuild();
  generator.increment_indent();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected_fragment = R"(  void tint_symbol() {
    return;
  })";
  EXPECT_THAT(generator.result(), HasSubstr(expected_fragment));
}
63
// A non-entry-point function taking f32 and i32 parameters; these are
// expected to be emitted as HLSL `float` and `int` parameters.
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithParams) {
  ast::VariableList params{Param("a", ty.f32()), Param("b", ty.i32())};
  Func("my_func", params, ty.void_(), {Return()});

  GeneratorImpl& generator = Build();
  generator.increment_indent();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected = R"(  void my_func(float a, int b) {
    return;
  }
)";
  EXPECT_EQ(generator.result(), expected);
}
81
// A fragment entry point with an empty body (no explicit return statement).
// The generator is expected to emit a trailing `return;` regardless.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_NoReturn_Void) {
  ast::StatementList empty_body;  // deliberately no explicit return
  Func("main", {}, ty.void_(), empty_body,
       {Stage(ast::PipelineStage::kFragment)});

  GeneratorImpl& generator = Build();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected = R"(void main() {
  return;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
97
// A function-storage pointer parameter is expected to be emitted as an HLSL
// `inout` parameter, with the explicit dereference dropped from the body.
TEST_F(HlslGeneratorImplTest_Function, PtrParameter) {
  // WGSL:
  //   fn f(foo : ptr<function, f32>) -> f32 {
  //     return *foo;
  //   }
  auto* foo_param =
      Param("foo", ty.pointer<f32>(ast::StorageClass::kFunction));
  Func("f", {foo_param}, ty.f32(), {Return(Deref("foo"))});

  GeneratorImpl& generator = SanitizeAndBuild();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected_fragment = R"(float f(inout float foo) {
  return foo;
}
)";
  EXPECT_THAT(generator.result(), HasSubstr(expected_fragment));
}
113
// A fragment entry point with a location-decorated input and output. After
// sanitizing, the original body is expected to move into `frag_main_inner`,
// and a `frag_main` wrapper takes / returns generated interface structs
// carrying the TEXCOORD0 input and SV_Target1 output semantics.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_WithInOutVars) {
  // WGSL:
  //   [[stage(fragment)]]
  //   fn frag_main([[location(0)]] foo : f32) -> [[location(1)]] f32 {
  //     return foo;
  //   }
  auto* foo_param = Param("foo", ty.f32(), {Location(0)});
  Func("frag_main", {foo_param}, ty.f32(), {Return("foo")},
       {Stage(ast::PipelineStage::kFragment)}, {Location(1)});

  GeneratorImpl& generator = SanitizeAndBuild();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected = R"(struct tint_symbol_1 {
  float foo : TEXCOORD0;
};
struct tint_symbol_2 {
  float value : SV_Target1;
};

float frag_main_inner(float foo) {
  return foo;
}

tint_symbol_2 frag_main(tint_symbol_1 tint_symbol) {
  const float inner_result = frag_main_inner(tint_symbol.foo);
  tint_symbol_2 wrapper_result = (tint_symbol_2)0;
  wrapper_result.value = inner_result;
  return wrapper_result;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
145
// A fragment entry point whose input and output are builtins rather than
// locations. After sanitizing, the builtins are expected to appear as the
// SV_Position / SV_Depth semantics on the generated interface structs.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_WithInOut_Builtins) {
  // WGSL:
  //   [[stage(fragment)]]
  //   fn frag_main([[builtin(position)]] coord : vec4<f32>)
  //       -> [[builtin(frag_depth)]] f32 {
  //     return coord.x;
  //   }
  auto* coord_in =
      Param("coord", ty.vec4<f32>(), {Builtin(ast::Builtin::kPosition)});
  Func("frag_main", ast::VariableList{coord_in}, ty.f32(),
       {Return(MemberAccessor("coord", "x"))},
       {Stage(ast::PipelineStage::kFragment)},
       {Builtin(ast::Builtin::kFragDepth)});

  GeneratorImpl& gen = SanitizeAndBuild();

  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), R"(struct tint_symbol_1 {
  float4 coord : SV_Position;
};
struct tint_symbol_2 {
  float value : SV_Depth;
};

float frag_main_inner(float4 coord) {
  return coord.x;
}

tint_symbol_2 frag_main(tint_symbol_1 tint_symbol) {
  const float inner_result = frag_main_inner(tint_symbol.coord);
  tint_symbol_2 wrapper_result = (tint_symbol_2)0;
  wrapper_result.value = inner_result;
  return wrapper_result;
}
)");
}
180
// One decorated struct is used as the vertex-stage output and as the
// fragment-stage input. Each entry point is expected to get its own generated
// interface struct, while the plain `Interface` struct is shared by the
// `_inner` functions.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_SharedStruct_DifferentStages) {
  // WGSL:
  //   struct Interface {
  //     [[builtin(position)]] pos : vec4<f32>;
  //     [[location(1)]] col1 : f32;
  //     [[location(2)]] col2 : f32;
  //   };
  //   [[stage(vertex)]]
  //   fn vert_main() -> Interface {
  //     return Interface(vec4<f32>(), 0.5, 0.25);
  //   }
  //   [[stage(fragment)]]
  //   fn frag_main(inputs : Interface) {
  //     let r : f32 = inputs.col1;
  //     let g : f32 = inputs.col2;
  //     let p : vec4<f32> = inputs.pos;
  //   }
  auto* interface_struct = Structure(
      "Interface",
      {
          Member("pos", ty.vec4<f32>(), {Builtin(ast::Builtin::kPosition)}),
          Member("col1", ty.f32(), {Location(1)}),
          Member("col2", ty.f32(), {Location(2)}),
      });

  Func("vert_main", {}, ty.Of(interface_struct),
       {Return(Construct(ty.Of(interface_struct), Construct(ty.vec4<f32>()),
                         Expr(0.5f), Expr(0.25f)))},
       {Stage(ast::PipelineStage::kVertex)});

  Func("frag_main", {Param("inputs", ty.Of(interface_struct))}, ty.void_(),
       {
           Decl(Const("r", ty.f32(), MemberAccessor("inputs", "col1"))),
           Decl(Const("g", ty.f32(), MemberAccessor("inputs", "col2"))),
           Decl(Const("p", ty.vec4<f32>(), MemberAccessor("inputs", "pos"))),
       },
       {Stage(ast::PipelineStage::kFragment)});

  GeneratorImpl& gen = SanitizeAndBuild();

  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), R"(struct Interface {
  float4 pos;
  float col1;
  float col2;
};
struct tint_symbol {
  float col1 : TEXCOORD1;
  float col2 : TEXCOORD2;
  float4 pos : SV_Position;
};

Interface vert_main_inner() {
  const Interface tint_symbol_3 = {float4(0.0f, 0.0f, 0.0f, 0.0f), 0.5f, 0.25f};
  return tint_symbol_3;
}

tint_symbol vert_main() {
  const Interface inner_result = vert_main_inner();
  tint_symbol wrapper_result = (tint_symbol)0;
  wrapper_result.pos = inner_result.pos;
  wrapper_result.col1 = inner_result.col1;
  wrapper_result.col2 = inner_result.col2;
  return wrapper_result;
}

struct tint_symbol_2 {
  float col1 : TEXCOORD1;
  float col2 : TEXCOORD2;
  float4 pos : SV_Position;
};

void frag_main_inner(Interface inputs) {
  const float r = inputs.col1;
  const float g = inputs.col2;
  const float4 p = inputs.pos;
}

void frag_main(tint_symbol_2 tint_symbol_1) {
  const Interface tint_symbol_4 = {tint_symbol_1.pos, tint_symbol_1.col1, tint_symbol_1.col2};
  frag_main_inner(tint_symbol_4);
  return;
}
)");
}
264
// Two vertex entry points return the same struct, built by a shared helper
// function. The helper is expected to be emitted once. Each entry point is
// expected to get its own generated output struct and wrapper, with the
// wrapper locals uniquified (`inner_result`, `inner_result_1`, ...).
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_SharedStruct_HelperFunction) {
  // WGSL:
  //   struct VertexOutput {
  //     [[builtin(position)]] pos : vec4<f32>;
  //   };
  //   fn foo(x : f32) -> VertexOutput {
  //     return VertexOutput(vec4<f32>(x, x, x, 1.0));
  //   }
  //   [[stage(vertex)]]
  //   fn vert_main1() -> VertexOutput {
  //     return foo(0.5);
  //   }
  //   [[stage(vertex)]]
  //   fn vert_main2() -> VertexOutput {
  //     return foo(0.25);
  //   }
  auto* vertex_output = Structure(
      "VertexOutput",
      {Member("pos", ty.vec4<f32>(), {Builtin(ast::Builtin::kPosition)})});

  auto* position = Construct(ty.vec4<f32>(), "x", "x", "x", Expr(1.f));
  Func("foo", {Param("x", ty.f32())}, ty.Of(vertex_output),
       {Return(Construct(ty.Of(vertex_output), position))});

  Func("vert_main1", {}, ty.Of(vertex_output),
       {Return(Call("foo", Expr(0.5f)))},
       {Stage(ast::PipelineStage::kVertex)});

  Func("vert_main2", {}, ty.Of(vertex_output),
       {Return(Call("foo", Expr(0.25f)))},
       {Stage(ast::PipelineStage::kVertex)});

  GeneratorImpl& generator = SanitizeAndBuild();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected = R"(struct VertexOutput {
  float4 pos;
};

VertexOutput foo(float x) {
  const VertexOutput tint_symbol_2 = {float4(x, x, x, 1.0f)};
  return tint_symbol_2;
}

struct tint_symbol {
  float4 pos : SV_Position;
};

VertexOutput vert_main1_inner() {
  return foo(0.5f);
}

tint_symbol vert_main1() {
  const VertexOutput inner_result = vert_main1_inner();
  tint_symbol wrapper_result = (tint_symbol)0;
  wrapper_result.pos = inner_result.pos;
  return wrapper_result;
}

struct tint_symbol_1 {
  float4 pos : SV_Position;
};

VertexOutput vert_main2_inner() {
  return foo(0.25f);
}

tint_symbol_1 vert_main2() {
  const VertexOutput inner_result_1 = vert_main2_inner();
  tint_symbol_1 wrapper_result_1 = (tint_symbol_1)0;
  wrapper_result_1.pos = inner_result_1.pos;
  return wrapper_result_1;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
338
// A uniform buffer is read by a helper function that a fragment entry point
// calls. After sanitizing, the uniform is expected to be emitted as a cbuffer
// of uint4 (binding 0 / group 1 -> register(b0, space1)). The f32 member
// component read is expected to go through asfloat().
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_With_Uniform) {
  auto* ubo_struct = Structure("UBO", {Member("coord", ty.vec4<f32>())},
                               {create<ast::StructBlockDecoration>()});
  auto* ubo = Global("ubo", ty.Of(ubo_struct), ast::StorageClass::kUniform,
                     ast::DecorationList{
                         create<ast::BindingDecoration>(0),
                         create<ast::GroupDecoration>(1),
                     });

  auto* coord_x = MemberAccessor(MemberAccessor(ubo, "coord"), "x");
  Func("sub_func", {Param("param", ty.f32())}, ty.f32(), {Return(coord_x)});

  Func("frag_main", {}, ty.void_(),
       {
           Decl(Var("v", ty.f32(), ast::StorageClass::kNone,
                    Call("sub_func", 1.0f))),
           Return(),
       },
       {Stage(ast::PipelineStage::kFragment)});

  GeneratorImpl& generator = SanitizeAndBuild();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected =
      R"(cbuffer cbuffer_ubo : register(b0, space1) {
  uint4 ubo[1];
};

float sub_func(float param) {
  return asfloat(ubo[0].x);
}

void frag_main() {
  float v = sub_func(1.0f);
  return;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
387
// A fragment entry point reads a component of a uniform struct member
// directly. The uniform is emitted as a cbuffer holding an array of uint4, so
// the program must be sanitized for the member access to be lowered. Without
// sanitizing, the output would index the `uint4[1]` array as if it were a
// struct (`uniforms.coord.x`), which is not valid HLSL. With sanitizing, the
// f32 at byte offset 0 is read as `asfloat(uniforms[0].x)`, matching the
// other uniform tests in this file.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_With_UniformStruct) {
  auto* s = Structure("Uniforms", {Member("coord", ty.vec4<f32>())},
                      {create<ast::StructBlockDecoration>()});

  Global("uniforms", ty.Of(s), ast::StorageClass::kUniform,
         ast::DecorationList{
             create<ast::BindingDecoration>(0),
             create<ast::GroupDecoration>(1),
         });

  auto* var = Var("v", ty.f32(), ast::StorageClass::kNone,
                  MemberAccessor(MemberAccessor("uniforms", "coord"), "x"));

  Func("frag_main", ast::VariableList{}, ty.void_(),
       {
           Decl(var),
           Return(),
       },
       {
           Stage(ast::PipelineStage::kFragment),
       });

  GeneratorImpl& gen = SanitizeAndBuild();

  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), R"(cbuffer cbuffer_uniforms : register(b0, space1) {
  uint4 uniforms[1];
};

void frag_main() {
  float v = asfloat(uniforms[0].x);
  return;
}
)");
}
424
// Reading member `b` of a read_write storage buffer from a fragment entry
// point. The buffer is expected to become an RWByteAddressBuffer in
// register(u0, space1), and the f32 read a Load at byte offset 4 wrapped in
// asfloat().
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_With_RW_StorageBuffer_Read) {
  auto* data_struct =
      Structure("Data", {Member("a", ty.i32()), Member("b", ty.f32())},
                {create<ast::StructBlockDecoration>()});

  Global("coord", ty.Of(data_struct), ast::StorageClass::kStorage,
         ast::Access::kReadWrite,
         ast::DecorationList{
             create<ast::BindingDecoration>(0),
             create<ast::GroupDecoration>(1),
         });

  Func("frag_main", {}, ty.void_(),
       {
           Decl(Var("v", ty.f32(), ast::StorageClass::kNone,
                    MemberAccessor("coord", "b"))),
           Return(),
       },
       {Stage(ast::PipelineStage::kFragment)});

  GeneratorImpl& generator = SanitizeAndBuild();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected =
      R"(RWByteAddressBuffer coord : register(u0, space1);

void frag_main() {
  float v = asfloat(coord.Load(4u));
  return;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
465
// Same read as the read_write case, but through a read-only storage buffer.
// The buffer is expected to become a ByteAddressBuffer bound to a t-register
// (register(t0, space1)) instead of a u-register.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_With_RO_StorageBuffer_Read) {
  auto* data_struct =
      Structure("Data", {Member("a", ty.i32()), Member("b", ty.f32())},
                {create<ast::StructBlockDecoration>()});

  Global("coord", ty.Of(data_struct), ast::StorageClass::kStorage,
         ast::Access::kRead,
         ast::DecorationList{
             create<ast::BindingDecoration>(0),
             create<ast::GroupDecoration>(1),
         });

  Func("frag_main", {}, ty.void_(),
       {
           Decl(Var("v", ty.f32(), ast::StorageClass::kNone,
                    MemberAccessor("coord", "b"))),
           Return(),
       },
       {Stage(ast::PipelineStage::kFragment)});

  GeneratorImpl& generator = SanitizeAndBuild();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected =
      R"(ByteAddressBuffer coord : register(t0, space1);

void frag_main() {
  float v = asfloat(coord.Load(4u));
  return;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
505
// Writing member `b` of a write-only storage buffer. The buffer is expected to
// be an RWByteAddressBuffer, and the f32 assignment a Store at byte offset 4
// of the value's bits (asuint).
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_With_WO_StorageBuffer_Store) {
  auto* data_struct =
      Structure("Data", {Member("a", ty.i32()), Member("b", ty.f32())},
                {create<ast::StructBlockDecoration>()});

  Global("coord", ty.Of(data_struct), ast::StorageClass::kStorage,
         ast::Access::kWrite,
         ast::DecorationList{
             create<ast::BindingDecoration>(0),
             create<ast::GroupDecoration>(1),
         });

  Func("frag_main", {}, ty.void_(),
       {Assign(MemberAccessor("coord", "b"), Expr(2.0f)), Return()},
       {Stage(ast::PipelineStage::kFragment)});

  GeneratorImpl& generator = SanitizeAndBuild();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected =
      R"(RWByteAddressBuffer coord : register(u0, space1);

void frag_main() {
  coord.Store(4u, asuint(2.0f));
  return;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
542
// Writing member `b` of a read_write storage buffer. The expected output
// matches the write-only case: an RWByteAddressBuffer Store at byte offset 4.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_With_StorageBuffer_Store) {
  auto* data_struct =
      Structure("Data", {Member("a", ty.i32()), Member("b", ty.f32())},
                {create<ast::StructBlockDecoration>()});

  Global("coord", ty.Of(data_struct), ast::StorageClass::kStorage,
         ast::Access::kReadWrite,
         ast::DecorationList{
             create<ast::BindingDecoration>(0),
             create<ast::GroupDecoration>(1),
         });

  Func("frag_main", {}, ty.void_(),
       {Assign(MemberAccessor("coord", "b"), Expr(2.0f)), Return()},
       {Stage(ast::PipelineStage::kFragment)});

  GeneratorImpl& generator = SanitizeAndBuild();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected =
      R"(RWByteAddressBuffer coord : register(u0, space1);

void frag_main() {
  coord.Store(4u, asuint(2.0f));
  return;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
580
// A helper function called by a fragment entry point reads an f32 member of a
// uniform struct. The uniform is emitted as a cbuffer holding `uint4 coord[1]`,
// so the program must be sanitized for the member access to be lowered.
// Without sanitizing, the output would be `coord.x`, which treats the uint4
// array as a struct and is not valid HLSL. With sanitizing, the f32 at byte
// offset 0 is read as `asfloat(coord[0].x)`, matching the other uniform tests
// in this file.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_Called_By_EntryPoint_With_Uniform) {
  auto* s = Structure("S", {Member("x", ty.f32())},
                      {create<ast::StructBlockDecoration>()});
  Global("coord", ty.Of(s), ast::StorageClass::kUniform,
         ast::DecorationList{
             create<ast::BindingDecoration>(0),
             create<ast::GroupDecoration>(1),
         });

  Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(),
       {
           Return(MemberAccessor("coord", "x")),
       });

  auto* var =
      Var("v", ty.f32(), ast::StorageClass::kNone, Call("sub_func", 1.0f));

  Func("frag_main", ast::VariableList{}, ty.void_(),
       {
           Decl(var),
           Return(),
       },
       {
           Stage(ast::PipelineStage::kFragment),
       });

  GeneratorImpl& gen = SanitizeAndBuild();

  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), R"(cbuffer cbuffer_coord : register(b0, space1) {
  uint4 coord[1];
};

float sub_func(float param) {
  return asfloat(coord[0].x);
}

void frag_main() {
  float v = sub_func(1.0f);
  return;
}
)");
}
625
// A helper function called by a fragment entry point reads an f32 member of a
// read_write storage buffer. The read is expected to be a byte-address Load
// at offset 0, reinterpreted with asfloat().
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_Called_By_EntryPoint_With_StorageBuffer) {
  auto* s_struct = Structure("S", {Member("x", ty.f32())},
                             {create<ast::StructBlockDecoration>()});
  Global("coord", ty.Of(s_struct), ast::StorageClass::kStorage,
         ast::Access::kReadWrite,
         ast::DecorationList{
             create<ast::BindingDecoration>(0),
             create<ast::GroupDecoration>(1),
         });

  Func("sub_func", {Param("param", ty.f32())}, ty.f32(),
       {Return(MemberAccessor("coord", "x"))});

  Func("frag_main", {}, ty.void_(),
       {
           Decl(Var("v", ty.f32(), ast::StorageClass::kNone,
                    Call("sub_func", 1.0f))),
           Return(),
       },
       {Stage(ast::PipelineStage::kFragment)});

  GeneratorImpl& generator = SanitizeAndBuild();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected =
      R"(RWByteAddressBuffer coord : register(u0, space1);

float sub_func(float param) {
  return asfloat(coord.Load(0u));
}

void frag_main() {
  float v = sub_func(1.0f);
  return;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
670
// An entry point named "GeometryShader" collides with an HLSL name and is
// expected to be renamed to "tint_symbol" by the sanitizer. A trailing return
// is expected even though the body is empty.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_WithNameCollision) {
  ast::StatementList empty_body;
  Func("GeometryShader", {}, ty.void_(), empty_body,
       {Stage(ast::PipelineStage::kFragment)});

  GeneratorImpl& generator = SanitizeAndBuild();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected = R"(void tint_symbol() {
  return;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
686
// A compute entry point with workgroup_size(1) is expected to be prefixed by
// `[numthreads(1, 1, 1)]`; the unspecified dimensions are emitted as 1.
TEST_F(HlslGeneratorImplTest_Function, Emit_Decoration_EntryPoint_Compute) {
  ast::DecorationList compute_decos{Stage(ast::PipelineStage::kCompute),
                                    WorkgroupSize(1)};
  Func("main", {}, ty.void_(), {Return()}, compute_decos);

  GeneratorImpl& generator = Build();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected = R"([numthreads(1, 1, 1)]
void main() {
  return;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
703
// Literal workgroup_size(2, 4, 6) dimensions are expected to be copied
// straight into the numthreads attribute.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_Compute_WithWorkgroup_Literal) {
  ast::DecorationList compute_decos{Stage(ast::PipelineStage::kCompute),
                                    WorkgroupSize(2, 4, 6)};
  ast::StatementList empty_body;
  Func("main", {}, ty.void_(), empty_body, compute_decos);

  GeneratorImpl& generator = Build();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected = R"([numthreads(2, 4, 6)]
void main() {
  return;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
721
// workgroup_size given by module-scope constants. The constants are expected
// to be emitted as `static const` declarations, and their values (2, 3, 4)
// folded into the numthreads attribute.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_Compute_WithWorkgroup_Const) {
  GlobalConst("width", ty.i32(), Construct(ty.i32(), 2));
  GlobalConst("height", ty.i32(), Construct(ty.i32(), 3));
  GlobalConst("depth", ty.i32(), Construct(ty.i32(), 4));

  ast::DecorationList compute_decos{
      Stage(ast::PipelineStage::kCompute),
      WorkgroupSize("width", "height", "depth"),
  };
  ast::StatementList empty_body;
  Func("main", {}, ty.void_(), empty_body, compute_decos);

  GeneratorImpl& generator = Build();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected = R"(static const int width = int(2);
static const int height = int(3);
static const int depth = int(4);

[numthreads(2, 3, 4)]
void main() {
  return;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
746
// workgroup_size given by overridable constants (override ids 7, 8, 9). Each
// constant is expected to be backed by a WGSL_SPEC_CONSTANT_<id> macro with
// its default value, and numthreads is expected to reference the macros
// rather than folded values.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_Compute_WithWorkgroup_OverridableConst) {
  GlobalConst("width", ty.i32(), Construct(ty.i32(), 2), {Override(7u)});
  GlobalConst("height", ty.i32(), Construct(ty.i32(), 3), {Override(8u)});
  GlobalConst("depth", ty.i32(), Construct(ty.i32(), 4), {Override(9u)});

  ast::DecorationList compute_decos{
      Stage(ast::PipelineStage::kCompute),
      WorkgroupSize("width", "height", "depth"),
  };
  ast::StatementList empty_body;
  Func("main", {}, ty.void_(), empty_body, compute_decos);

  GeneratorImpl& generator = Build();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected = R"(#ifndef WGSL_SPEC_CONSTANT_7
#define WGSL_SPEC_CONSTANT_7 int(2)
#endif
static const int width = WGSL_SPEC_CONSTANT_7;
#ifndef WGSL_SPEC_CONSTANT_8
#define WGSL_SPEC_CONSTANT_8 int(3)
#endif
static const int height = WGSL_SPEC_CONSTANT_8;
#ifndef WGSL_SPEC_CONSTANT_9
#define WGSL_SPEC_CONSTANT_9 int(4)
#endif
static const int depth = WGSL_SPEC_CONSTANT_9;

[numthreads(WGSL_SPEC_CONSTANT_7, WGSL_SPEC_CONSTANT_8, WGSL_SPEC_CONSTANT_9)]
void main() {
  return;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
780
// A fixed-size array parameter is expected to use C-style array declarator
// syntax in the HLSL parameter list.
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
  auto* array_param = Param("a", ty.array<f32, 5>());
  Func("my_func", {array_param}, ty.void_(), {Return()});

  GeneratorImpl& generator = Build();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected = R"(void my_func(float a[5]) {
  return;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
795
// An array return type is expected to be emitted through a generated typedef
// (`my_func_ret`), and a zero-value array construction as `(float[5])0`.
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayReturn) {
  auto* zero_array = Construct(ty.array<f32, 5>());
  Func("my_func", {}, ty.array<f32, 5>(), {Return(zero_array)});

  GeneratorImpl& generator = Build();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected = R"(typedef float my_func_ret[5];
my_func_ret my_func() {
  return (float[5])0;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
811
// A conditional discard in a void function is expected to be emitted as-is,
// with no extra wrapping around the body.
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithDiscardAndVoidReturn) {
  auto* discard_if_zero =
      If(Equal("a", 0), Block(create<ast::DiscardStatement>()));
  Func("my_func", {Param("a", ty.i32())}, ty.void_(),
       {discard_if_zero, Return()});

  GeneratorImpl& generator = Build();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected = R"(void my_func(int a) {
  if ((a == 0)) {
    discard;
  }
  return;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
831
// A conditional discard in a function returning i32. The generator is
// expected to wrap the body in `if (true) { ... }` and follow it with an
// uninitialized `unused` local that is returned, so every path returns a
// value.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Function_WithDiscardAndNonVoidReturn) {
  auto* discard_if_zero =
      If(Equal("a", 0), Block(create<ast::DiscardStatement>()));
  Func("my_func", {Param("a", ty.i32())}, ty.i32(),
       {discard_if_zero, Return(42)});

  GeneratorImpl& generator = Build();
  ASSERT_TRUE(generator.Generate()) << generator.error();

  const std::string expected = R"(int my_func(int a) {
  if (true) {
    if ((a == 0)) {
      discard;
    }
    return 42;
  }
  int unused;
  return unused;
}
)";
  EXPECT_EQ(generator.result(), expected);
}
856
857 // https://crbug.com/tint/297
// Regression test for https://crbug.com/tint/297: two compute entry points
// read the same module-scope storage buffer. The buffer declaration is
// expected to be emitted once, and each entry point keeps its own local `v`.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Multiple_EntryPoint_With_Same_ModuleVar) {
  // WGSL:
  //   [[block]] struct Data {
  //     d : f32;
  //   };
  //   [[binding(0), group(0)]] var<storage, read_write> data : Data;
  //
  //   [[stage(compute), workgroup_size(1)]]
  //   fn a() {
  //     var v = data.d;
  //     return;
  //   }
  //
  //   [[stage(compute), workgroup_size(1)]]
  //   fn b() {
  //     var v = data.d;
  //     return;
  //   }

  auto* s = Structure("Data", {Member("d", ty.f32())},
                      {create<ast::StructBlockDecoration>()});

  Global("data", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kReadWrite,
         ast::DecorationList{
             create<ast::BindingDecoration>(0),
             create<ast::GroupDecoration>(0),
         });

  {
    auto* var = Var("v", ty.f32(), ast::StorageClass::kNone,
                    MemberAccessor("data", "d"));

    Func("a", ast::VariableList{}, ty.void_(),
         {
             Decl(var),
             Return(),
         },
         {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
  }

  {
    auto* var = Var("v", ty.f32(), ast::StorageClass::kNone,
                    MemberAccessor("data", "d"));

    Func("b", ast::VariableList{}, ty.void_(),
         {
             Decl(var),
             Return(),
         },
         {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
  }

  GeneratorImpl& gen = SanitizeAndBuild();

  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), R"(RWByteAddressBuffer data : register(u0, space0);

[numthreads(1, 1, 1)]
void a() {
  float v = asfloat(data.Load(0u));
  return;
}

[numthreads(1, 1, 1)]
void b() {
  float v = asfloat(data.Load(0u));
  return;
}
)");
}
928
929 } // namespace
930 } // namespace hlsl
931 } // namespace writer
932 } // namespace tint
933