• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "gmock/gmock.h"
16 #include "src/ast/stage_decoration.h"
17 #include "src/ast/struct_block_decoration.h"
18 #include "src/ast/variable_decl_statement.h"
19 #include "src/ast/workgroup_decoration.h"
20 #include "src/writer/hlsl/test_helper.h"
21 
22 using ::testing::HasSubstr;
23 
24 namespace tint {
25 namespace writer {
26 namespace hlsl {
27 namespace {
28 
29 using HlslGeneratorImplTest_Function = TestHelper;
30 
// A parameterless void function is emitted as-is; the indent level raised
// before generation prefixes every emitted line with two spaces.
TEST_F(HlslGeneratorImplTest_Function, Emit_Function) {
  ast::StatementList body{Return()};
  Func("my_func", ast::VariableList{}, ty.void_(), body);

  auto& gen = Build();
  gen.increment_indent();

  const char* const expected = R"(  void my_func() {
    return;
  }
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
47 
// A function whose name collides with a name the HLSL writer treats as
// reserved ("GeometryShader") is renamed by the sanitizer.
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Name_Collision) {
  ast::StatementList body{Return()};
  Func("GeometryShader", ast::VariableList{}, ty.void_(), body);

  auto& gen = SanitizeAndBuild();
  gen.increment_indent();

  const char* const renamed_fn = R"(  void tint_symbol() {
    return;
  })";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_THAT(gen.result(), HasSubstr(renamed_fn));
}
63 
// Scalar parameters are emitted with their HLSL types (f32 -> float,
// i32 -> int) in declaration order.
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithParams) {
  ast::VariableList params{Param("a", ty.f32()), Param("b", ty.i32())};
  ast::StatementList body{Return()};
  Func("my_func", params, ty.void_(), body);

  auto& gen = Build();
  gen.increment_indent();

  const char* const expected = R"(  void my_func(float a, int b) {
    return;
  }
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
81 
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_NoReturn_Void) {
  // The fragment entry point has an empty body; the emitted HLSL is expected
  // to contain an explicit `return;` anyway.
  ast::DecorationList entry_point_decos{Stage(ast::PipelineStage::kFragment)};
  Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{},
       entry_point_decos);

  auto& gen = Build();

  const char* const expected = R"(void main() {
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
97 
// A function-storage pointer parameter becomes an `inout` parameter and the
// dereference disappears from the emitted body.
TEST_F(HlslGeneratorImplTest_Function, PtrParameter) {
  // fn f(foo : ptr<function, f32>) -> f32 {
  //   return *foo;
  // }
  auto* foo_param =
      Param("foo", ty.pointer<f32>(ast::StorageClass::kFunction));
  ast::StatementList body{Return(Deref("foo"))};
  Func("f", {foo_param}, ty.f32(), body);

  auto& gen = SanitizeAndBuild();

  const char* const lowered_fn = R"(float f(inout float foo) {
  return foo;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_THAT(gen.result(), HasSubstr(lowered_fn));
}
113 
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_WithInOutVars) {
  // fn frag_main([[location(0)]] foo : f32) -> [[location(1)]] f32 {
  //   return foo;
  // }
  //
  // The sanitizer wraps location-decorated inputs/outputs in structs with
  // TEXCOORD / SV_Target semantics and moves the user body into *_inner.
  auto* foo_param = Param("foo", ty.f32(), {Location(0)});
  ast::DecorationList stage_decos{Stage(ast::PipelineStage::kFragment)};
  ast::DecorationList return_decos{Location(1)};
  Func("frag_main", ast::VariableList{foo_param}, ty.f32(), {Return("foo")},
       stage_decos, return_decos);

  auto& gen = SanitizeAndBuild();

  const char* const expected = R"(struct tint_symbol_1 {
  float foo : TEXCOORD0;
};
struct tint_symbol_2 {
  float value : SV_Target1;
};

float frag_main_inner(float foo) {
  return foo;
}

tint_symbol_2 frag_main(tint_symbol_1 tint_symbol) {
  const float inner_result = frag_main_inner(tint_symbol.foo);
  tint_symbol_2 wrapper_result = (tint_symbol_2)0;
  wrapper_result.value = inner_result;
  return wrapper_result;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
145 
// Builtin-decorated entry point inputs/outputs are mapped to HLSL system-value
// semantics (position -> SV_Position, frag_depth -> SV_Depth).
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_WithInOut_Builtins) {
  // fn frag_main([[builtin(position)]] coord : vec4<f32>)
  //     -> [[builtin(frag_depth)]] f32 {
  //   return coord.x;
  // }
  auto* coord_in =
      Param("coord", ty.vec4<f32>(), {Builtin(ast::Builtin::kPosition)});
  Func("frag_main", ast::VariableList{coord_in}, ty.f32(),
       {Return(MemberAccessor("coord", "x"))},
       {Stage(ast::PipelineStage::kFragment)},
       {Builtin(ast::Builtin::kFragDepth)});

  GeneratorImpl& gen = SanitizeAndBuild();

  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), R"(struct tint_symbol_1 {
  float4 coord : SV_Position;
};
struct tint_symbol_2 {
  float value : SV_Depth;
};

float frag_main_inner(float4 coord) {
  return coord.x;
}

tint_symbol_2 frag_main(tint_symbol_1 tint_symbol) {
  const float inner_result = frag_main_inner(tint_symbol.coord);
  tint_symbol_2 wrapper_result = (tint_symbol_2)0;
  wrapper_result.value = inner_result;
  return wrapper_result;
}
)");
}
180 
// One struct is used as the vertex output and the fragment input. Each entry
// point gets its own semantic-annotated wrapper struct.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_SharedStruct_DifferentStages) {
  // struct Interface {
  //   [[builtin(position)]] pos : vec4<f32>;
  //   [[location(1)]] col1 : f32;
  //   [[location(2)]] col2 : f32;
  // };
  // fn vert_main() -> Interface {
  //   return Interface(vec4<f32>(), 0.5, 0.25);
  // }
  // fn frag_main(inputs : Interface) {
  //   const r = inputs.col1;
  //   const g = inputs.col2;
  //   const p = inputs.pos;
  // }
  auto* interface_struct = Structure(
      "Interface",
      {
          Member("pos", ty.vec4<f32>(), {Builtin(ast::Builtin::kPosition)}),
          Member("col1", ty.f32(), {Location(1)}),
          Member("col2", ty.f32(), {Location(2)}),
      });

  Func("vert_main", {}, ty.Of(interface_struct),
       {Return(Construct(ty.Of(interface_struct), Construct(ty.vec4<f32>()),
                         Expr(0.5f), Expr(0.25f)))},
       {Stage(ast::PipelineStage::kVertex)});

  Func("frag_main", {Param("inputs", ty.Of(interface_struct))}, ty.void_(),
       {
           Decl(Const("r", ty.f32(), MemberAccessor("inputs", "col1"))),
           Decl(Const("g", ty.f32(), MemberAccessor("inputs", "col2"))),
           Decl(Const("p", ty.vec4<f32>(), MemberAccessor("inputs", "pos"))),
       },
       {Stage(ast::PipelineStage::kFragment)});

  GeneratorImpl& gen = SanitizeAndBuild();

  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), R"(struct Interface {
  float4 pos;
  float col1;
  float col2;
};
struct tint_symbol {
  float col1 : TEXCOORD1;
  float col2 : TEXCOORD2;
  float4 pos : SV_Position;
};

Interface vert_main_inner() {
  const Interface tint_symbol_3 = {float4(0.0f, 0.0f, 0.0f, 0.0f), 0.5f, 0.25f};
  return tint_symbol_3;
}

tint_symbol vert_main() {
  const Interface inner_result = vert_main_inner();
  tint_symbol wrapper_result = (tint_symbol)0;
  wrapper_result.pos = inner_result.pos;
  wrapper_result.col1 = inner_result.col1;
  wrapper_result.col2 = inner_result.col2;
  return wrapper_result;
}

struct tint_symbol_2 {
  float col1 : TEXCOORD1;
  float col2 : TEXCOORD2;
  float4 pos : SV_Position;
};

void frag_main_inner(Interface inputs) {
  const float r = inputs.col1;
  const float g = inputs.col2;
  const float4 p = inputs.pos;
}

void frag_main(tint_symbol_2 tint_symbol_1) {
  const Interface tint_symbol_4 = {tint_symbol_1.pos, tint_symbol_1.col1, tint_symbol_1.col2};
  frag_main_inner(tint_symbol_4);
  return;
}
)");
}
264 
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_SharedStruct_HelperFunction) {
  // struct VertexOutput {
  //   [[builtin(position)]] pos : vec4<f32>;
  // };
  // fn foo(x : f32) -> VertexOutput {
  //   return VertexOutput(vec4<f32>(x, x, x, 1.0));
  // }
  // fn vert_main1() -> VertexOutput {
  //   return foo(0.5);
  // }
  // fn vert_main2() -> VertexOutput {
  //   return foo(0.25);
  // }
  auto* vertex_output = Structure(
      "VertexOutput",
      {Member("pos", ty.vec4<f32>(), {Builtin(ast::Builtin::kPosition)})});

  // Non-entry-point helper returning the shared struct.
  auto* splat_x = Construct(ty.vec4<f32>(), "x", "x", "x", Expr(1.f));
  Func("foo", {Param("x", ty.f32())}, ty.Of(vertex_output),
       {Return(Construct(ty.Of(vertex_output), splat_x))});

  // The two vertex entry points differ only in the value forwarded to foo().
  auto add_vertex_entry_point = [&](const char* name, float value) {
    Func(name, {}, ty.Of(vertex_output), {Return(Call("foo", Expr(value)))},
         {Stage(ast::PipelineStage::kVertex)});
  };
  add_vertex_entry_point("vert_main1", 0.5f);
  add_vertex_entry_point("vert_main2", 0.25f);

  auto& gen = SanitizeAndBuild();

  const char* const expected = R"(struct VertexOutput {
  float4 pos;
};

VertexOutput foo(float x) {
  const VertexOutput tint_symbol_2 = {float4(x, x, x, 1.0f)};
  return tint_symbol_2;
}

struct tint_symbol {
  float4 pos : SV_Position;
};

VertexOutput vert_main1_inner() {
  return foo(0.5f);
}

tint_symbol vert_main1() {
  const VertexOutput inner_result = vert_main1_inner();
  tint_symbol wrapper_result = (tint_symbol)0;
  wrapper_result.pos = inner_result.pos;
  return wrapper_result;
}

struct tint_symbol_1 {
  float4 pos : SV_Position;
};

VertexOutput vert_main2_inner() {
  return foo(0.25f);
}

tint_symbol_1 vert_main2() {
  const VertexOutput inner_result_1 = vert_main2_inner();
  tint_symbol_1 wrapper_result_1 = (tint_symbol_1)0;
  wrapper_result_1.pos = inner_result_1.pos;
  return wrapper_result_1;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
338 
// A helper that reads a uniform buffer member is emitted against a cbuffer
// holding a uint4 array, with the read wrapped in asfloat().
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_With_Uniform) {
  auto* ubo_struct = Structure("UBO", {Member("coord", ty.vec4<f32>())},
                               {create<ast::StructBlockDecoration>()});
  auto* ubo_var = Global("ubo", ty.Of(ubo_struct), ast::StorageClass::kUniform,
                         ast::DecorationList{
                             create<ast::BindingDecoration>(0),
                             create<ast::GroupDecoration>(1),
                         });

  // fn sub_func(param : f32) -> f32 { return ubo.coord.x; }
  auto* coord_x = MemberAccessor(MemberAccessor(ubo_var, "coord"), "x");
  Func("sub_func", {Param("param", ty.f32())}, ty.f32(), {Return(coord_x)});

  // Fragment entry point: var v : f32 = sub_func(1.0); return;
  auto* v_decl =
      Var("v", ty.f32(), ast::StorageClass::kNone, Call("sub_func", 1.0f));
  ast::StatementList frag_body{Decl(v_decl), Return()};
  Func("frag_main", {}, ty.void_(), frag_body,
       {Stage(ast::PipelineStage::kFragment)});

  auto& gen = SanitizeAndBuild();

  const char* const expected = R"(cbuffer cbuffer_ubo : register(b0, space1) {
  uint4 ubo[1];
};

float sub_func(float param) {
  return asfloat(ubo[0].x);
}

void frag_main() {
  float v = sub_func(1.0f);
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
387 
// Built with Build() rather than SanitizeAndBuild(): the uniform is declared
// as a uint4 array while the member access is emitted as written.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_With_UniformStruct) {
  auto* uniforms_struct =
      Structure("Uniforms", {Member("coord", ty.vec4<f32>())},
                {create<ast::StructBlockDecoration>()});

  Global("uniforms", ty.Of(uniforms_struct), ast::StorageClass::kUniform,
         ast::DecorationList{
             create<ast::BindingDecoration>(0),
             create<ast::GroupDecoration>(1),
         });

  auto* coord_x = MemberAccessor(MemberAccessor("uniforms", "coord"), "x");
  auto* v_decl = Var("v", ty.f32(), ast::StorageClass::kNone, coord_x);

  ast::StatementList frag_body{Decl(v_decl), Return()};
  Func("frag_main", ast::VariableList{}, ty.void_(), frag_body,
       {Stage(ast::PipelineStage::kFragment)});

  auto& gen = Build();

  const char* const expected =
      R"(cbuffer cbuffer_uniforms : register(b0, space1) {
  uint4 uniforms[1];
};

void frag_main() {
  float v = uniforms.coord.x;
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
424 
// A read_write storage buffer becomes an RWByteAddressBuffer bound to a `u`
// register, and member reads become byte-offset Load()s.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_With_RW_StorageBuffer_Read) {
  auto* data_struct = Structure("Data",
                                {
                                    Member("a", ty.i32()),
                                    Member("b", ty.f32()),
                                },
                                {create<ast::StructBlockDecoration>()});

  Global("coord", ty.Of(data_struct), ast::StorageClass::kStorage,
         ast::Access::kReadWrite,
         ast::DecorationList{
             create<ast::BindingDecoration>(0),
             create<ast::GroupDecoration>(1),
         });

  // var v : f32 = coord.b;   (the expected Load uses byte offset 4)
  auto* v_decl = Var("v", ty.f32(), ast::StorageClass::kNone,
                     MemberAccessor("coord", "b"));
  ast::StatementList frag_body{Decl(v_decl), Return()};
  Func("frag_main", ast::VariableList{}, ty.void_(), frag_body,
       {Stage(ast::PipelineStage::kFragment)});

  auto& gen = SanitizeAndBuild();

  const char* const expected =
      R"(RWByteAddressBuffer coord : register(u0, space1);

void frag_main() {
  float v = asfloat(coord.Load(4u));
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
465 
// A read-only storage buffer becomes a ByteAddressBuffer bound to a `t`
// register.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_With_RO_StorageBuffer_Read) {
  auto* data_struct = Structure("Data",
                                {
                                    Member("a", ty.i32()),
                                    Member("b", ty.f32()),
                                },
                                {create<ast::StructBlockDecoration>()});

  Global("coord", ty.Of(data_struct), ast::StorageClass::kStorage,
         ast::Access::kRead,
         ast::DecorationList{
             create<ast::BindingDecoration>(0),
             create<ast::GroupDecoration>(1),
         });

  auto* v_decl = Var("v", ty.f32(), ast::StorageClass::kNone,
                     MemberAccessor("coord", "b"));
  ast::StatementList frag_body{Decl(v_decl), Return()};
  Func("frag_main", ast::VariableList{}, ty.void_(), frag_body,
       {Stage(ast::PipelineStage::kFragment)});

  auto& gen = SanitizeAndBuild();

  const char* const expected =
      R"(ByteAddressBuffer coord : register(t0, space1);

void frag_main() {
  float v = asfloat(coord.Load(4u));
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
505 
// A write-only storage buffer is still emitted as an RWByteAddressBuffer, and
// member writes become byte-offset Store()s of the asuint() bits.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_With_WO_StorageBuffer_Store) {
  auto* data_struct = Structure("Data",
                                {
                                    Member("a", ty.i32()),
                                    Member("b", ty.f32()),
                                },
                                {create<ast::StructBlockDecoration>()});

  Global("coord", ty.Of(data_struct), ast::StorageClass::kStorage,
         ast::Access::kWrite,
         ast::DecorationList{
             create<ast::BindingDecoration>(0),
             create<ast::GroupDecoration>(1),
         });

  auto* store_b = Assign(MemberAccessor("coord", "b"), Expr(2.0f));
  ast::StatementList frag_body{store_b, Return()};
  Func("frag_main", ast::VariableList{}, ty.void_(), frag_body,
       {Stage(ast::PipelineStage::kFragment)});

  auto& gen = SanitizeAndBuild();

  const char* const expected =
      R"(RWByteAddressBuffer coord : register(u0, space1);

void frag_main() {
  coord.Store(4u, asuint(2.0f));
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
542 
// Writing a member of a read_write storage buffer emits a byte-offset Store().
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_With_StorageBuffer_Store) {
  auto* data_struct = Structure("Data",
                                {
                                    Member("a", ty.i32()),
                                    Member("b", ty.f32()),
                                },
                                {create<ast::StructBlockDecoration>()});

  Global("coord", ty.Of(data_struct), ast::StorageClass::kStorage,
         ast::Access::kReadWrite,
         ast::DecorationList{
             create<ast::BindingDecoration>(0),
             create<ast::GroupDecoration>(1),
         });

  auto* store_b = Assign(MemberAccessor("coord", "b"), Expr(2.0f));
  ast::StatementList frag_body{store_b, Return()};
  Func("frag_main", ast::VariableList{}, ty.void_(), frag_body,
       {Stage(ast::PipelineStage::kFragment)});

  auto& gen = SanitizeAndBuild();

  const char* const expected =
      R"(RWByteAddressBuffer coord : register(u0, space1);

void frag_main() {
  coord.Store(4u, asuint(2.0f));
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
580 
// A non-entry-point function reading a uniform, called from a fragment entry
// point. Built without the sanitizer, so the access is emitted as written.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_Called_By_EntryPoint_With_Uniform) {
  auto* s_struct = Structure("S", {Member("x", ty.f32())},
                             {create<ast::StructBlockDecoration>()});
  Global("coord", ty.Of(s_struct), ast::StorageClass::kUniform,
         ast::DecorationList{
             create<ast::BindingDecoration>(0),
             create<ast::GroupDecoration>(1),
         });

  auto* return_x = Return(MemberAccessor("coord", "x"));
  Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(),
       {return_x});

  auto* v_decl =
      Var("v", ty.f32(), ast::StorageClass::kNone, Call("sub_func", 1.0f));
  ast::StatementList frag_body{Decl(v_decl), Return()};
  Func("frag_main", ast::VariableList{}, ty.void_(), frag_body,
       {Stage(ast::PipelineStage::kFragment)});

  auto& gen = Build();

  const char* const expected =
      R"(cbuffer cbuffer_coord : register(b0, space1) {
  uint4 coord[1];
};

float sub_func(float param) {
  return coord.x;
}

void frag_main() {
  float v = sub_func(1.0f);
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
625 
// A non-entry-point function reading a read_write storage buffer: the read
// inside the helper is lowered to an RWByteAddressBuffer Load().
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_Called_By_EntryPoint_With_StorageBuffer) {
  auto* s_struct = Structure("S", {Member("x", ty.f32())},
                             {create<ast::StructBlockDecoration>()});
  Global("coord", ty.Of(s_struct), ast::StorageClass::kStorage,
         ast::Access::kReadWrite,
         ast::DecorationList{
             create<ast::BindingDecoration>(0),
             create<ast::GroupDecoration>(1),
         });

  auto* return_x = Return(MemberAccessor("coord", "x"));
  Func("sub_func", ast::VariableList{Param("param", ty.f32())}, ty.f32(),
       {return_x});

  auto* v_decl =
      Var("v", ty.f32(), ast::StorageClass::kNone, Call("sub_func", 1.0f));
  ast::StatementList frag_body{Decl(v_decl), Return()};
  Func("frag_main", ast::VariableList{}, ty.void_(), frag_body,
       {Stage(ast::PipelineStage::kFragment)});

  auto& gen = SanitizeAndBuild();

  const char* const expected =
      R"(RWByteAddressBuffer coord : register(u0, space1);

float sub_func(float param) {
  return asfloat(coord.Load(0u));
}

void frag_main() {
  float v = sub_func(1.0f);
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
670 
// An empty fragment entry point named "GeometryShader" is renamed to
// tint_symbol by the sanitizer and gains an explicit `return;`.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_WithNameCollision) {
  ast::DecorationList fragment_stage{Stage(ast::PipelineStage::kFragment)};
  Func("GeometryShader", ast::VariableList{}, ty.void_(), ast::StatementList{},
       fragment_stage);

  auto& gen = SanitizeAndBuild();

  const char* const expected = R"(void tint_symbol() {
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
686 
// A compute entry point with workgroup_size(1) gets [numthreads(1, 1, 1)];
// omitted dimensions are emitted as 1.
TEST_F(HlslGeneratorImplTest_Function, Emit_Decoration_EntryPoint_Compute) {
  ast::DecorationList compute_decos{Stage(ast::PipelineStage::kCompute),
                                    WorkgroupSize(1)};
  Func("main", ast::VariableList{}, ty.void_(), {Return()}, compute_decos);

  auto& gen = Build();

  const char* const expected = R"([numthreads(1, 1, 1)]
void main() {
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
703 
// Literal workgroup_size dimensions are copied straight into [numthreads].
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_Compute_WithWorkgroup_Literal) {
  ast::DecorationList compute_decos{
      Stage(ast::PipelineStage::kCompute),
      WorkgroupSize(2, 4, 6),
  };
  Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{},
       compute_decos);

  auto& gen = Build();

  const char* const expected = R"([numthreads(2, 4, 6)]
void main() {
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
721 
// workgroup_size expressions naming module-scope constants are emitted as the
// constants' literal values inside [numthreads].
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_Compute_WithWorkgroup_Const) {
  // Declares `let <name> : i32 = i32(<value>);` at module scope.
  auto declare_i32_const = [&](const char* name, int value) {
    GlobalConst(name, ty.i32(), Construct(ty.i32(), value));
  };
  declare_i32_const("width", 2);
  declare_i32_const("height", 3);
  declare_i32_const("depth", 4);

  ast::DecorationList compute_decos{
      Stage(ast::PipelineStage::kCompute),
      WorkgroupSize("width", "height", "depth"),
  };
  Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{},
       compute_decos);

  auto& gen = Build();

  const char* const expected = R"(static const int width = int(2);
static const int height = int(3);
static const int depth = int(4);

[numthreads(2, 3, 4)]
void main() {
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
746 
// Overridable constants become WGSL_SPEC_CONSTANT_<id> macros with their
// initializer as the default, and [numthreads] references the macros.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Decoration_EntryPoint_Compute_WithWorkgroup_OverridableConst) {
  // Declares `[[override(<id>)]] let <name> : i32 = i32(<value>);`.
  auto declare_overridable = [&](const char* name, int value, unsigned id) {
    GlobalConst(name, ty.i32(), Construct(ty.i32(), value), {Override(id)});
  };
  declare_overridable("width", 2, 7u);
  declare_overridable("height", 3, 8u);
  declare_overridable("depth", 4, 9u);

  ast::DecorationList compute_decos{
      Stage(ast::PipelineStage::kCompute),
      WorkgroupSize("width", "height", "depth"),
  };
  Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{},
       compute_decos);

  auto& gen = Build();

  const char* const expected = R"(#ifndef WGSL_SPEC_CONSTANT_7
#define WGSL_SPEC_CONSTANT_7 int(2)
#endif
static const int width = WGSL_SPEC_CONSTANT_7;
#ifndef WGSL_SPEC_CONSTANT_8
#define WGSL_SPEC_CONSTANT_8 int(3)
#endif
static const int height = WGSL_SPEC_CONSTANT_8;
#ifndef WGSL_SPEC_CONSTANT_9
#define WGSL_SPEC_CONSTANT_9 int(4)
#endif
static const int depth = WGSL_SPEC_CONSTANT_9;

[numthreads(WGSL_SPEC_CONSTANT_7, WGSL_SPEC_CONSTANT_8, WGSL_SPEC_CONSTANT_9)]
void main() {
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
780 
// A fixed-size array parameter uses C-style declarator syntax: `float a[5]`.
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
  auto* array_param = Param("a", ty.array<f32, 5>());
  Func("my_func", ast::VariableList{array_param}, ty.void_(), {Return()});

  auto& gen = Build();

  const char* const expected = R"(void my_func(float a[5]) {
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
795 
// An array return type is emitted through a `<fn>_ret` typedef, and the
// zero-value array constructor becomes a `(float[5])0` cast.
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayReturn) {
  auto* zero_array = Construct(ty.array<f32, 5>());
  Func("my_func", {}, ty.array<f32, 5>(), {Return(zero_array)});

  auto& gen = Build();

  const char* const expected = R"(typedef float my_func_ret[5];
my_func_ret my_func() {
  return (float[5])0;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
811 
// In a void function, a conditional discard is emitted unchanged.
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithDiscardAndVoidReturn) {
  auto* discard_if_zero =
      If(Equal("a", 0), Block(create<ast::DiscardStatement>()));
  Func("my_func", {Param("a", ty.i32())}, ty.void_(),
       {discard_if_zero, Return()});

  auto& gen = Build();

  const char* const expected = R"(void my_func(int a) {
  if ((a == 0)) {
    discard;
  }
  return;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
831 
// In a value-returning function containing a discard, the writer wraps the
// body in `if (true) { ... }` and follows it with a return of an
// uninitialized local of the return type.
TEST_F(HlslGeneratorImplTest_Function,
       Emit_Function_WithDiscardAndNonVoidReturn) {
  auto* discard_if_zero =
      If(Equal("a", 0), Block(create<ast::DiscardStatement>()));
  Func("my_func", {Param("a", ty.i32())}, ty.i32(),
       {discard_if_zero, Return(42)});

  auto& gen = Build();

  const char* const expected = R"(int my_func(int a) {
  if (true) {
    if ((a == 0)) {
      discard;
    }
    return 42;
  }
  int unused;
  return unused;
}
)";
  ASSERT_TRUE(gen.Generate()) << gen.error();
  EXPECT_EQ(gen.result(), expected);
}
856 
857 // https://crbug.com/tint/297
TEST_F(HlslGeneratorImplTest_Function,Emit_Multiple_EntryPoint_With_Same_ModuleVar)858 TEST_F(HlslGeneratorImplTest_Function,
859        Emit_Multiple_EntryPoint_With_Same_ModuleVar) {
860   // [[block]] struct Data {
861   //   d : f32;
862   // };
863   // [[binding(0), group(0)]] var<storage> data : Data;
864   //
865   // [[stage(compute), workgroup_size(1)]]
866   // fn a() {
867   //   var v = data.d;
868   //   return;
869   // }
870   //
871   // [[stage(compute), workgroup_size(1)]]
872   // fn b() {
873   //   var v = data.d;
874   //   return;
875   // }
876 
877   auto* s = Structure("Data", {Member("d", ty.f32())},
878                       {create<ast::StructBlockDecoration>()});
879 
880   Global("data", ty.Of(s), ast::StorageClass::kStorage, ast::Access::kReadWrite,
881          ast::DecorationList{
882              create<ast::BindingDecoration>(0),
883              create<ast::GroupDecoration>(0),
884          });
885 
886   {
887     auto* var = Var("v", ty.f32(), ast::StorageClass::kNone,
888                     MemberAccessor("data", "d"));
889 
890     Func("a", ast::VariableList{}, ty.void_(),
891          {
892              Decl(var),
893              Return(),
894          },
895          {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
896   }
897 
898   {
899     auto* var = Var("v", ty.f32(), ast::StorageClass::kNone,
900                     MemberAccessor("data", "d"));
901 
902     Func("b", ast::VariableList{}, ty.void_(),
903          {
904              Decl(var),
905              Return(),
906          },
907          {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
908   }
909 
910   GeneratorImpl& gen = SanitizeAndBuild();
911 
912   ASSERT_TRUE(gen.Generate()) << gen.error();
913   EXPECT_EQ(gen.result(), R"(RWByteAddressBuffer data : register(u0, space0);
914 
915 [numthreads(1, 1, 1)]
916 void a() {
917   float v = asfloat(data.Load(0u));
918   return;
919 }
920 
921 [numthreads(1, 1, 1)]
922 void b() {
923   float v = asfloat(data.Load(0u));
924   return;
925 }
926 )");
927 }
928 
929 }  // namespace
930 }  // namespace hlsl
931 }  // namespace writer
932 }  // namespace tint
933