1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/resolver/resolver.h"
16
17 #include <tuple>
18
19 #include "gmock/gmock.h"
20 #include "gtest/gtest-spi.h"
21 #include "src/ast/assignment_statement.h"
22 #include "src/ast/bitcast_expression.h"
23 #include "src/ast/break_statement.h"
24 #include "src/ast/call_statement.h"
25 #include "src/ast/continue_statement.h"
26 #include "src/ast/float_literal_expression.h"
27 #include "src/ast/if_statement.h"
28 #include "src/ast/intrinsic_texture_helper_test.h"
29 #include "src/ast/loop_statement.h"
30 #include "src/ast/override_decoration.h"
31 #include "src/ast/return_statement.h"
32 #include "src/ast/stage_decoration.h"
33 #include "src/ast/struct_block_decoration.h"
34 #include "src/ast/switch_statement.h"
35 #include "src/ast/unary_op_expression.h"
36 #include "src/ast/variable_decl_statement.h"
37 #include "src/ast/workgroup_decoration.h"
38 #include "src/resolver/resolver_test_helper.h"
39 #include "src/sem/call.h"
40 #include "src/sem/function.h"
41 #include "src/sem/member_accessor_expression.h"
42 #include "src/sem/reference_type.h"
43 #include "src/sem/sampled_texture_type.h"
44 #include "src/sem/statement.h"
45 #include "src/sem/variable.h"
46
47 using ::testing::ElementsAre;
48 using ::testing::HasSubstr;
49
50 namespace tint {
51 namespace resolver {
52 namespace {
53
54 // Helpers and typedefs
55 template <typename T>
56 using DataType = builder::DataType<T>;
57 template <int N, typename T>
58 using vec = builder::vec<N, T>;
59 template <typename T>
60 using vec2 = builder::vec2<T>;
61 template <typename T>
62 using vec3 = builder::vec3<T>;
63 template <typename T>
64 using vec4 = builder::vec4<T>;
65 template <int N, int M, typename T>
66 using mat = builder::mat<N, M, T>;
67 template <typename T>
68 using mat2x2 = builder::mat2x2<T>;
69 template <typename T>
70 using mat2x3 = builder::mat2x3<T>;
71 template <typename T>
72 using mat3x2 = builder::mat3x2<T>;
73 template <typename T>
74 using mat3x3 = builder::mat3x3<T>;
75 template <typename T>
76 using mat4x4 = builder::mat4x4<T>;
77 template <typename T, int ID = 0>
78 using alias = builder::alias<T, ID>;
79 template <typename T>
80 using alias1 = builder::alias1<T>;
81 template <typename T>
82 using alias2 = builder::alias2<T>;
83 template <typename T>
84 using alias3 = builder::alias3<T>;
85 using f32 = builder::f32;
86 using i32 = builder::i32;
87 using u32 = builder::u32;
88 using Op = ast::BinaryOp;
89
TEST_F(ResolverTest, Stmt_Assign) {
  // Inside a wrapper function:
  //   var v : f32;
  //   v = 2.3;
  auto* var_v = Var("v", ty.f32());
  auto* target = Expr("v");
  auto* value = Expr(2.3f);
  auto* stmt = Assign(target, value);
  WrapInFunction(var_v, stmt);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* target_ty = TypeOf(target);
  auto* value_ty = TypeOf(value);
  ASSERT_NE(target_ty, nullptr);
  ASSERT_NE(value_ty, nullptr);

  // The assignment target's store type is f32; the literal is a plain f32.
  EXPECT_TRUE(target_ty->UnwrapRef()->Is<sem::F32>());
  EXPECT_TRUE(value_ty->Is<sem::F32>());
  // Both operands are attributed to the assignment statement.
  EXPECT_EQ(StmtOf(target), stmt);
  EXPECT_EQ(StmtOf(value), stmt);
}
108
TEST_F(ResolverTest, Stmt_Case) {
  // Inside a wrapper function:
  //   var v : f32;
  //   var c : i32;
  //   switch (c) {
  //     case 3: { v = 2.3; }
  //     default: {}
  //   }
  auto* var_v = Var("v", ty.f32());
  auto* target = Expr("v");
  auto* value = Expr(2.3f);

  auto* assign_stmt = Assign(target, value);
  auto* case_body = Block(assign_stmt);

  ast::CaseSelectorList selectors;
  selectors.push_back(create<ast::SintLiteralExpression>(3));
  auto* case_stmt = create<ast::CaseStatement>(selectors, case_body);

  auto* selector_var = Var("c", ty.i32());
  auto* switch_stmt = Switch(selector_var, case_stmt, DefaultCase());
  WrapInFunction(var_v, selector_var, switch_stmt);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* target_ty = TypeOf(target);
  auto* value_ty = TypeOf(value);
  ASSERT_NE(target_ty, nullptr);
  ASSERT_NE(value_ty, nullptr);

  EXPECT_TRUE(target_ty->UnwrapRef()->Is<sem::F32>());
  EXPECT_TRUE(value_ty->Is<sem::F32>());
  EXPECT_EQ(StmtOf(target), assign_stmt);
  EXPECT_EQ(StmtOf(value), assign_stmt);
  // The assignment lives in the case body block.
  EXPECT_EQ(BlockOf(assign_stmt), case_body);
}
133
TEST_F(ResolverTest, Stmt_Block) {
  // Inside a wrapper function:
  //   var v : f32;
  //   {
  //     v = 2.3;
  //   }
  auto* var_v = Var("v", ty.f32());
  auto* target = Expr("v");
  auto* value = Expr(2.3f);

  auto* assign_stmt = Assign(target, value);
  auto* inner_block = Block(assign_stmt);
  WrapInFunction(var_v, inner_block);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* target_ty = TypeOf(target);
  auto* value_ty = TypeOf(value);
  ASSERT_NE(target_ty, nullptr);
  ASSERT_NE(value_ty, nullptr);

  EXPECT_TRUE(target_ty->UnwrapRef()->Is<sem::F32>());
  EXPECT_TRUE(value_ty->Is<sem::F32>());

  // Statement attribution.
  EXPECT_EQ(StmtOf(target), assign_stmt);
  EXPECT_EQ(StmtOf(value), assign_stmt);

  // Block attribution for the expressions and the statement.
  EXPECT_EQ(BlockOf(target), inner_block);
  EXPECT_EQ(BlockOf(value), inner_block);
  EXPECT_EQ(BlockOf(assign_stmt), inner_block);
}
155
TEST_F(ResolverTest, Stmt_If) {
  // Inside a wrapper function:
  //   var v : f32;
  //   if (true) {
  //     v = 2.3;
  //   } elseif (true) {
  //     v = 2.3;
  //   }
  auto* var_v = Var("v", ty.f32());

  // The else-if branch.
  auto* elif_target = Expr("v");
  auto* elif_value = Expr(2.3f);
  auto* elif_block = Block(Assign(elif_target, elif_value));
  auto* elif_cond = Expr(true);
  auto* elif_stmt = create<ast::ElseStatement>(elif_cond, elif_block);

  // The primary if branch.
  auto* if_target = Expr("v");
  auto* if_value = Expr(2.3f);
  auto* if_assign = Assign(if_target, if_value);
  auto* if_block = Block(if_assign);
  auto* if_cond = Expr(true);
  auto* if_stmt = create<ast::IfStatement>(if_cond, if_block,
                                           ast::ElseStatementList{elif_stmt});
  WrapInFunction(var_v, if_stmt);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  ASSERT_NE(TypeOf(if_stmt->condition), nullptr);
  ASSERT_NE(TypeOf(elif_target), nullptr);
  ASSERT_NE(TypeOf(elif_value), nullptr);
  ASSERT_NE(TypeOf(if_target), nullptr);
  ASSERT_NE(TypeOf(if_value), nullptr);

  // Types.
  EXPECT_TRUE(TypeOf(if_stmt->condition)->Is<sem::Bool>());
  EXPECT_TRUE(TypeOf(elif_target)->UnwrapRef()->Is<sem::F32>());
  EXPECT_TRUE(TypeOf(elif_value)->Is<sem::F32>());
  EXPECT_TRUE(TypeOf(if_target)->UnwrapRef()->Is<sem::F32>());
  EXPECT_TRUE(TypeOf(if_value)->Is<sem::F32>());

  // Statement ownership: conditions belong to their if / else statements.
  EXPECT_EQ(StmtOf(if_target), if_assign);
  EXPECT_EQ(StmtOf(if_value), if_assign);
  EXPECT_EQ(StmtOf(if_cond), if_stmt);
  EXPECT_EQ(StmtOf(elif_cond), elif_stmt);

  // Block ownership of each branch's expressions.
  EXPECT_EQ(BlockOf(if_target), if_block);
  EXPECT_EQ(BlockOf(if_value), if_block);
  EXPECT_EQ(BlockOf(elif_target), elif_block);
  EXPECT_EQ(BlockOf(elif_value), elif_block);
}
197
TEST_F(ResolverTest, Stmt_Loop) {
  // Inside a wrapper function:
  //   var v : f32;
  //   loop {
  //     v = 2.3;
  //     continuing {
  //       v = 2.3;
  //     }
  //   }
  auto* var_v = Var("v", ty.f32());

  auto* loop_target = Expr("v");
  auto* loop_value = Expr(2.3f);
  auto* loop_body = Block(Assign(loop_target, loop_value));

  auto* cont_target = Expr("v");
  auto* cont_value = Expr(2.3f);
  auto* cont_body = Block(Assign(cont_target, cont_value));

  auto* loop_stmt = Loop(loop_body, cont_body);
  WrapInFunction(var_v, loop_stmt);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  ASSERT_NE(TypeOf(loop_target), nullptr);
  ASSERT_NE(TypeOf(loop_value), nullptr);
  ASSERT_NE(TypeOf(cont_target), nullptr);
  ASSERT_NE(TypeOf(cont_value), nullptr);

  EXPECT_TRUE(TypeOf(loop_target)->UnwrapRef()->Is<sem::F32>());
  EXPECT_TRUE(TypeOf(loop_value)->Is<sem::F32>());
  EXPECT_TRUE(TypeOf(cont_target)->UnwrapRef()->Is<sem::F32>());
  EXPECT_TRUE(TypeOf(cont_value)->Is<sem::F32>());

  // Each expression is owned by the block it was written in.
  EXPECT_EQ(BlockOf(loop_target), loop_body);
  EXPECT_EQ(BlockOf(loop_value), loop_body);
  EXPECT_EQ(BlockOf(cont_target), cont_body);
  EXPECT_EQ(BlockOf(cont_value), cont_body);
}
226
TEST_F(ResolverTest, Stmt_Return) {
  // fn test() -> i32 {
  //   return 2;
  // }
  auto* ret_value = Expr(2);
  Func("test", {}, ty.i32(), {Return(ret_value)}, {});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  // The returned literal resolves to i32.
  auto* ret_ty = TypeOf(ret_value);
  ASSERT_NE(ret_ty, nullptr);
  EXPECT_TRUE(ret_ty->Is<sem::I32>());
}
238
TEST_F(ResolverTest, Stmt_Return_WithoutValue) {
  // A bare `return;` inside a wrapper function must resolve cleanly.
  WrapInFunction(Return());

  EXPECT_TRUE(r()->Resolve()) << r()->error();
}
245
TEST_F(ResolverTest, Stmt_Switch) {
  // Inside a wrapper function:
  //   var v : f32;
  //   switch (2) {
  //     case 3: { v = 2.3; }
  //     default: {}
  //   }
  auto* var_v = Var("v", ty.f32());
  auto* target = Expr("v");
  auto* value = Expr(2.3f);
  auto* body = Block(Assign(target, value));
  auto* switch_stmt = Switch(Expr(2), Case(Expr(3), body), DefaultCase());
  WrapInFunction(var_v, switch_stmt);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* selector_ty = TypeOf(switch_stmt->condition);
  auto* target_ty = TypeOf(target);
  auto* value_ty = TypeOf(value);
  ASSERT_NE(selector_ty, nullptr);
  ASSERT_NE(target_ty, nullptr);
  ASSERT_NE(value_ty, nullptr);

  EXPECT_TRUE(selector_ty->Is<sem::I32>());
  EXPECT_TRUE(target_ty->UnwrapRef()->Is<sem::F32>());
  EXPECT_TRUE(value_ty->Is<sem::F32>());
  // Expressions inside the case body belong to that case's block.
  EXPECT_EQ(BlockOf(target), body);
  EXPECT_EQ(BlockOf(value), body);
}
266
TEST_F(ResolverTest, Stmt_Call) {
  // fn my_func() { return; }
  // ...and, inside a wrapper function, the call statement:
  //   my_func();
  Func("my_func", ast::VariableList{}, ty.void_(), {Return()},
       ast::DecorationList{});

  auto* call_expr = Call("my_func");
  auto* call_stmt = CallStmt(call_expr);
  WrapInFunction(call_stmt);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* call_ty = TypeOf(call_expr);
  ASSERT_NE(call_ty, nullptr);
  EXPECT_TRUE(call_ty->Is<sem::Void>());
  EXPECT_EQ(StmtOf(call_expr), call_stmt);
}
282
TEST_F(ResolverTest, Stmt_VariableDecl) {
  // Inside a wrapper function:
  //   var my_var : i32 = 2;
  auto* initializer = Expr(2);
  auto* my_var =
      Var("my_var", ty.i32(), ast::StorageClass::kNone, initializer);
  WrapInFunction(Decl(my_var));

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* init_ty = TypeOf(initializer);
  ASSERT_NE(init_ty, nullptr);
  EXPECT_TRUE(init_ty->Is<sem::I32>());
}
295
TEST_F(ResolverTest, Stmt_VariableDecl_Alias) {
  // type MyInt = i32;
  // ...inside a wrapper function:
  //   var my_var : MyInt = 2;
  auto* my_int_alias = Alias("MyInt", ty.i32());
  auto* initializer = Expr(2);
  auto* my_var = Var("my_var", ty.Of(my_int_alias), ast::StorageClass::kNone,
                     initializer);
  WrapInFunction(Decl(my_var));

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  // The initializer literal itself is a plain i32.
  auto* init_ty = TypeOf(initializer);
  ASSERT_NE(init_ty, nullptr);
  EXPECT_TRUE(init_ty->Is<sem::I32>());
}
309
TEST_F(ResolverTest, Stmt_VariableDecl_ModuleScope) {
  // var<private> my_var : i32 = 2;
  auto* initializer = Expr(2);
  Global("my_var", ty.i32(), ast::StorageClass::kPrivate, initializer);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* init_ty = TypeOf(initializer);
  ASSERT_NE(init_ty, nullptr);
  EXPECT_TRUE(init_ty->Is<sem::I32>());
  // A module-scope initializer is not part of any statement.
  EXPECT_EQ(StmtOf(initializer), nullptr);
}
320
TEST_F(ResolverTest, Stmt_VariableDecl_OuterScopeAfterInnerScope) {
  // Re-declaring a name at function scope, after the inner block that
  // declared it has closed, must resolve each use to the correct declaration:
  //
  // fn func() {
  //   {
  //     var foo : i32 = 2;
  //     var bar : i32 = foo;
  //   }
  //   var foo : f32 = 2.0;
  //   var bar : f32 = foo;
  // }

  ast::VariableList params;

  // Declare i32 "foo" inside a block
  auto* foo_i32 = Var("foo", ty.i32(), ast::StorageClass::kNone, Expr(2));
  auto* foo_i32_init = foo_i32->constructor;
  auto* foo_i32_decl = Decl(foo_i32);

  // Reference "foo" inside the block
  auto* bar_i32 = Var("bar", ty.i32(), ast::StorageClass::kNone, Expr("foo"));
  auto* bar_i32_init = bar_i32->constructor;
  auto* bar_i32_decl = Decl(bar_i32);

  auto* inner = Block(foo_i32_decl, bar_i32_decl);

  // Declare f32 "foo" at function scope
  auto* foo_f32 = Var("foo", ty.f32(), ast::StorageClass::kNone, Expr(2.f));
  auto* foo_f32_init = foo_f32->constructor;
  auto* foo_f32_decl = Decl(foo_f32);

  // Reference "foo" at function scope
  auto* bar_f32 = Var("bar", ty.f32(), ast::StorageClass::kNone, Expr("foo"));
  auto* bar_f32_init = bar_f32->constructor;
  auto* bar_f32_decl = Decl(bar_f32);

  Func("func", params, ty.void_(), {inner, foo_f32_decl, bar_f32_decl},
       ast::DecorationList{});

  EXPECT_TRUE(r()->Resolve()) << r()->error();
  ASSERT_NE(TypeOf(foo_i32_init), nullptr);
  EXPECT_TRUE(TypeOf(foo_i32_init)->Is<sem::I32>());
  ASSERT_NE(TypeOf(foo_f32_init), nullptr);
  EXPECT_TRUE(TypeOf(foo_f32_init)->Is<sem::F32>());
  ASSERT_NE(TypeOf(bar_i32_init), nullptr);
  EXPECT_TRUE(TypeOf(bar_i32_init)->UnwrapRef()->Is<sem::I32>());
  ASSERT_NE(TypeOf(bar_f32_init), nullptr);
  EXPECT_TRUE(TypeOf(bar_f32_init)->UnwrapRef()->Is<sem::F32>());
  EXPECT_EQ(StmtOf(foo_i32_init), foo_i32_decl);
  EXPECT_EQ(StmtOf(bar_i32_init), bar_i32_decl);
  EXPECT_EQ(StmtOf(foo_f32_init), foo_f32_decl);
  EXPECT_EQ(StmtOf(bar_f32_init), bar_f32_decl);

  // Each "foo" identifier must bind to the declaration visible at its use.
  EXPECT_TRUE(CheckVarUsers(foo_i32, {bar_i32_init}));
  EXPECT_TRUE(CheckVarUsers(foo_f32, {bar_f32_init}));
  ASSERT_NE(VarOf(bar_i32_init), nullptr);
  EXPECT_EQ(VarOf(bar_i32_init)->Declaration(), foo_i32);
  ASSERT_NE(VarOf(bar_f32_init), nullptr);
  EXPECT_EQ(VarOf(bar_f32_init)->Declaration(), foo_f32);
}
378
TEST_F(ResolverTest, Stmt_VariableDecl_ModuleScopeAfterFunctionScope) {
  // A function-local "foo" must not leak out of its function; a later
  // function's use of "foo" binds to the module-scope declaration instead:
  //
  // fn func_i32() {
  //   var foo : i32 = 2;
  // }
  // var<private> foo : f32 = 2.0;
  // fn func_f32() {
  //   var bar : f32 = foo;
  // }

  ast::VariableList params;

  // Declare i32 "foo" inside a function
  auto* fn_i32 = Var("foo", ty.i32(), ast::StorageClass::kNone, Expr(2));
  auto* fn_i32_init = fn_i32->constructor;
  auto* fn_i32_decl = Decl(fn_i32);
  Func("func_i32", params, ty.void_(), {fn_i32_decl}, ast::DecorationList{});

  // Declare f32 "foo" at module scope
  auto* mod_f32 = Var("foo", ty.f32(), ast::StorageClass::kPrivate, Expr(2.f));
  auto* mod_init = mod_f32->constructor;
  AST().AddGlobalVariable(mod_f32);

  // Reference "foo" in another function
  auto* fn_f32 = Var("bar", ty.f32(), ast::StorageClass::kNone, Expr("foo"));
  auto* fn_f32_init = fn_f32->constructor;
  auto* fn_f32_decl = Decl(fn_f32);
  Func("func_f32", params, ty.void_(), {fn_f32_decl}, ast::DecorationList{});

  EXPECT_TRUE(r()->Resolve()) << r()->error();
  ASSERT_NE(TypeOf(mod_init), nullptr);
  EXPECT_TRUE(TypeOf(mod_init)->Is<sem::F32>());
  ASSERT_NE(TypeOf(fn_i32_init), nullptr);
  EXPECT_TRUE(TypeOf(fn_i32_init)->Is<sem::I32>());
  ASSERT_NE(TypeOf(fn_f32_init), nullptr);
  EXPECT_TRUE(TypeOf(fn_f32_init)->UnwrapRef()->Is<sem::F32>());
  EXPECT_EQ(StmtOf(fn_i32_init), fn_i32_decl);
  // Module-scope initializers have no owning statement.
  EXPECT_EQ(StmtOf(mod_init), nullptr);
  EXPECT_EQ(StmtOf(fn_f32_init), fn_f32_decl);

  // The function-local "foo" has no users; the module-scope one is used by
  // the initializer of "bar".
  EXPECT_TRUE(CheckVarUsers(fn_i32, {}));
  EXPECT_TRUE(CheckVarUsers(mod_f32, {fn_f32_init}));
  ASSERT_NE(VarOf(fn_f32_init), nullptr);
  EXPECT_EQ(VarOf(fn_f32_init)->Declaration(), mod_f32);
}
422
TEST_F(ResolverTest, ArraySize_UnsignedLiteral) {
  // var<private> a : array<f32, 10u>;
  auto* a =
      Global("a", ty.array(ty.f32(), Expr(10u)), ast::StorageClass::kPrivate);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  ASSERT_NE(TypeOf(a), nullptr);
  auto* ref = TypeOf(a)->As<sem::Reference>();
  ASSERT_NE(ref, nullptr);
  // Guard the cast so a non-array store type fails the test cleanly instead
  // of dereferencing a null pointer.
  auto* ary = ref->StoreType()->As<sem::Array>();
  ASSERT_NE(ary, nullptr);
  EXPECT_EQ(ary->Count(), 10u);
}
436
TEST_F(ResolverTest, ArraySize_SignedLiteral) {
  // var<private> a : array<f32, 10>;
  auto* a =
      Global("a", ty.array(ty.f32(), Expr(10)), ast::StorageClass::kPrivate);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  ASSERT_NE(TypeOf(a), nullptr);
  auto* ref = TypeOf(a)->As<sem::Reference>();
  ASSERT_NE(ref, nullptr);
  // Guard the cast so a non-array store type fails the test cleanly instead
  // of dereferencing a null pointer.
  auto* ary = ref->StoreType()->As<sem::Array>();
  ASSERT_NE(ary, nullptr);
  EXPECT_EQ(ary->Count(), 10u);
}
450
TEST_F(ResolverTest, ArraySize_UnsignedConstant) {
  // let size = 10u;
  // var<private> a : array<f32, size>;
  GlobalConst("size", nullptr, Expr(10u));
  auto* a = Global("a", ty.array(ty.f32(), Expr("size")),
                   ast::StorageClass::kPrivate);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  ASSERT_NE(TypeOf(a), nullptr);
  auto* ref = TypeOf(a)->As<sem::Reference>();
  ASSERT_NE(ref, nullptr);
  // Guard the cast so a non-array store type fails the test cleanly instead
  // of dereferencing a null pointer.
  auto* ary = ref->StoreType()->As<sem::Array>();
  ASSERT_NE(ary, nullptr);
  EXPECT_EQ(ary->Count(), 10u);
}
466
TEST_F(ResolverTest, ArraySize_SignedConstant) {
  // let size = 10;
  // var<private> a : array<f32, size>;
  GlobalConst("size", nullptr, Expr(10));
  auto* a = Global("a", ty.array(ty.f32(), Expr("size")),
                   ast::StorageClass::kPrivate);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  ASSERT_NE(TypeOf(a), nullptr);
  auto* ref = TypeOf(a)->As<sem::Reference>();
  ASSERT_NE(ref, nullptr);
  // Guard the cast so a non-array store type fails the test cleanly instead
  // of dereferencing a null pointer.
  auto* ary = ref->StoreType()->As<sem::Array>();
  ASSERT_NE(ary, nullptr);
  EXPECT_EQ(ary->Count(), 10u);
}
482
TEST_F(ResolverTest, Expr_Bitcast) {
  // var<private> name : f32;
  // ...inside a wrapper function:
  //   bitcast<f32>(name);
  Global("name", ty.f32(), ast::StorageClass::kPrivate);

  auto* operand = Expr("name");
  auto* cast_expr = create<ast::BitcastExpression>(ty.f32(), operand);
  WrapInFunction(cast_expr);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* result_ty = TypeOf(cast_expr);
  ASSERT_NE(result_ty, nullptr);
  EXPECT_TRUE(result_ty->Is<sem::F32>());
}
494
TEST_F(ResolverTest, Expr_Call) {
  // fn my_func() -> f32 { return 0.0; }
  // ...inside a wrapper function:
  //   my_func();
  Func("my_func", ast::VariableList{}, ty.f32(), {Return(0.0f)},
       ast::DecorationList{});

  auto* call_expr = Call("my_func");
  WrapInFunction(call_expr);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  // The call expression takes the callee's return type.
  auto* call_ty = TypeOf(call_expr);
  ASSERT_NE(call_ty, nullptr);
  EXPECT_TRUE(call_ty->Is<sem::F32>());
}
507
TEST_F(ResolverTest, Expr_Call_InBinaryOp) {
  // fn func() -> f32 { return 0.0; }
  // ...inside a wrapper function:
  //   func() + func();
  Func("func", ast::VariableList{}, ty.f32(), {Return(0.0f)},
       ast::DecorationList{});

  auto* sum = Add(Call("func"), Call("func"));
  WrapInFunction(sum);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* sum_ty = TypeOf(sum);
  ASSERT_NE(sum_ty, nullptr);
  EXPECT_TRUE(sum_ty->Is<sem::F32>());
}
520
TEST_F(ResolverTest, Expr_Call_WithParams) {
  // fn my_func(<unique name> : f32) -> f32 { return 1.2; }
  // ...inside a wrapper function:
  //   my_func(2.4);
  Func("my_func", {Param(Sym(), ty.f32())}, ty.f32(), {Return(1.2f)});

  auto* argument = Expr(2.4f);
  auto* call_expr = Call("my_func", argument);
  WrapInFunction(call_expr);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* arg_ty = TypeOf(argument);
  ASSERT_NE(arg_ty, nullptr);
  EXPECT_TRUE(arg_ty->Is<sem::F32>());
}
537
TEST_F(ResolverTest, Expr_Call_Intrinsic) {
  // Inside a wrapper function:
  //   round(2.4);
  auto* round_call = Call("round", 2.4f);
  WrapInFunction(round_call);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* result_ty = TypeOf(round_call);
  ASSERT_NE(result_ty, nullptr);
  EXPECT_TRUE(result_ty->Is<sem::F32>());
}
547
TEST_F(ResolverTest, Expr_Cast) {
  // var<private> name : f32;
  // ...inside a wrapper function:
  //   f32(name);
  Global("name", ty.f32(), ast::StorageClass::kPrivate);

  auto* conversion = Construct(ty.f32(), "name");
  WrapInFunction(conversion);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* result_ty = TypeOf(conversion);
  ASSERT_NE(result_ty, nullptr);
  EXPECT_TRUE(result_ty->Is<sem::F32>());
}
559
TEST_F(ResolverTest, Expr_Constructor_Scalar) {
  // A lone f32 literal `1.0` inside a wrapper function.
  auto* literal = Expr(1.0f);
  WrapInFunction(literal);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* literal_ty = TypeOf(literal);
  ASSERT_NE(literal_ty, nullptr);
  EXPECT_TRUE(literal_ty->Is<sem::F32>());
}
569
TEST_F(ResolverTest, Expr_Constructor_Type_Vec2) {
  // vec2<f32>(1.0, 1.0) inside a wrapper function.
  auto* ctor = vec2<f32>(1.0f, 1.0f);
  WrapInFunction(ctor);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* ctor_ty = TypeOf(ctor);
  ASSERT_NE(ctor_ty, nullptr);
  auto* vec_ty = ctor_ty->As<sem::Vector>();
  ASSERT_NE(vec_ty, nullptr);
  EXPECT_TRUE(vec_ty->type()->Is<sem::F32>());
  EXPECT_EQ(vec_ty->Width(), 2u);
}
581
TEST_F(ResolverTest, Expr_Constructor_Type_Vec3) {
  // vec3<f32>(1.0, 1.0, 1.0) inside a wrapper function.
  auto* ctor = vec3<f32>(1.0f, 1.0f, 1.0f);
  WrapInFunction(ctor);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* ctor_ty = TypeOf(ctor);
  ASSERT_NE(ctor_ty, nullptr);
  auto* vec_ty = ctor_ty->As<sem::Vector>();
  ASSERT_NE(vec_ty, nullptr);
  EXPECT_TRUE(vec_ty->type()->Is<sem::F32>());
  EXPECT_EQ(vec_ty->Width(), 3u);
}
593
TEST_F(ResolverTest, Expr_Constructor_Type_Vec4) {
  // vec4<f32>(1.0, 1.0, 1.0, 1.0) inside a wrapper function.
  auto* ctor = vec4<f32>(1.0f, 1.0f, 1.0f, 1.0f);
  WrapInFunction(ctor);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* ctor_ty = TypeOf(ctor);
  ASSERT_NE(ctor_ty, nullptr);
  auto* vec_ty = ctor_ty->As<sem::Vector>();
  ASSERT_NE(vec_ty, nullptr);
  EXPECT_TRUE(vec_ty->type()->Is<sem::F32>());
  EXPECT_EQ(vec_ty->Width(), 4u);
}
605
TEST_F(ResolverTest, Expr_Identifier_GlobalVariable) {
  // var<private> my_var : f32;
  // ...inside a wrapper function the identifier `my_var` is used.
  auto* global_var = Global("my_var", ty.f32(), ast::StorageClass::kPrivate);

  auto* use = Expr("my_var");
  WrapInFunction(use);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* use_ty = TypeOf(use);
  ASSERT_NE(use_ty, nullptr);
  // A `var` identifier is a reference whose store type is f32.
  ASSERT_TRUE(use_ty->Is<sem::Reference>());
  EXPECT_TRUE(use_ty->UnwrapRef()->Is<sem::F32>());

  EXPECT_TRUE(CheckVarUsers(global_var, {use}));
  auto* resolved = VarOf(use);
  ASSERT_NE(resolved, nullptr);
  EXPECT_EQ(resolved->Declaration(), global_var);
}
621
TEST_F(ResolverTest, Expr_Identifier_GlobalConstant) {
  // let my_var : f32 = f32();
  // ...inside a wrapper function the identifier `my_var` is used.
  auto* global_let = GlobalConst("my_var", ty.f32(), Construct(ty.f32()));

  auto* use = Expr("my_var");
  WrapInFunction(use);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  // A `let` identifier yields the value type directly (no reference).
  auto* use_ty = TypeOf(use);
  ASSERT_NE(use_ty, nullptr);
  EXPECT_TRUE(use_ty->Is<sem::F32>());

  EXPECT_TRUE(CheckVarUsers(global_let, {use}));
  auto* resolved = VarOf(use);
  ASSERT_NE(resolved, nullptr);
  EXPECT_EQ(resolved->Declaration(), global_let);
}
636
TEST_F(ResolverTest, Expr_Identifier_FunctionVariable_Const) {
  // fn my_func() {
  //   let my_var : f32 = f32();
  //   var b : f32 = my_var;
  // }
  auto* use = Expr("my_var");
  auto* local_let = Const("my_var", ty.f32(), Construct(ty.f32()));
  auto* b_decl = Decl(Var("b", ty.f32(), ast::StorageClass::kNone, use));

  Func("my_func", ast::VariableList{}, ty.void_(), {Decl(local_let), b_decl},
       ast::DecorationList{});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* use_ty = TypeOf(use);
  ASSERT_NE(use_ty, nullptr);
  EXPECT_TRUE(use_ty->Is<sem::F32>());
  EXPECT_EQ(StmtOf(use), b_decl);

  EXPECT_TRUE(CheckVarUsers(local_let, {use}));
  auto* resolved = VarOf(use);
  ASSERT_NE(resolved, nullptr);
  EXPECT_EQ(resolved->Declaration(), local_let);
}
658
TEST_F(ResolverTest, IndexAccessor_Dynamic_Ref_F32) {
  // fn my_func() {
  //   var a : array<bool, 10> = array<bool, 10>();
  //   var idx : f32 = f32();
  //   var f : f32 = a[idx];
  // }
  //
  // An f32 index must be rejected, with the diagnostic reported at the
  // source location attached to the `idx` identifier (12:34).
  auto* a = Var("a", ty.array<bool, 10>(), array<bool, 10>());
  auto* idx = Var("idx", ty.f32(), Construct(ty.f32()));
  auto* f = Var("f", ty.f32(), IndexAccessor("a", Expr(Source{{12, 34}}, idx)));
  Func("my_func", ast::VariableList{}, ty.void_(),
       {
           Decl(a),
           Decl(idx),
           Decl(f),
       },
       ast::DecorationList{});

  EXPECT_FALSE(r()->Resolve());
  EXPECT_EQ(r()->error(),
            "12:34 error: index must be of type 'i32' or 'u32', found: 'f32'");
}
678
TEST_F(ResolverTest, Expr_Identifier_FunctionVariable) {
  // fn my_func() {
  //   var my_var : f32;
  //   my_var = my_var;
  // }
  auto* lhs_use = Expr("my_var");
  auto* rhs_use = Expr("my_var");
  auto* self_assign = Assign(lhs_use, rhs_use);

  auto* local_var = Var("my_var", ty.f32());

  Func("my_func", ast::VariableList{}, ty.void_(),
       {Decl(local_var), self_assign}, ast::DecorationList{});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  // Left-hand use.
  auto* lhs_ty = TypeOf(lhs_use);
  ASSERT_NE(lhs_ty, nullptr);
  ASSERT_TRUE(lhs_ty->Is<sem::Reference>());
  EXPECT_TRUE(lhs_ty->UnwrapRef()->Is<sem::F32>());
  EXPECT_EQ(StmtOf(lhs_use), self_assign);

  // Right-hand use.
  auto* rhs_ty = TypeOf(rhs_use);
  ASSERT_NE(rhs_ty, nullptr);
  ASSERT_TRUE(rhs_ty->Is<sem::Reference>());
  EXPECT_TRUE(rhs_ty->UnwrapRef()->Is<sem::F32>());
  EXPECT_EQ(StmtOf(rhs_use), self_assign);

  // Both uses bind to the single local declaration.
  EXPECT_TRUE(CheckVarUsers(local_var, {lhs_use, rhs_use}));
  ASSERT_NE(VarOf(lhs_use), nullptr);
  EXPECT_EQ(VarOf(lhs_use)->Declaration(), local_var);
  ASSERT_NE(VarOf(rhs_use), nullptr);
  EXPECT_EQ(VarOf(rhs_use)->Declaration(), local_var);
}
709
TEST_F(ResolverTest, Expr_Identifier_Function_Ptr) {
  // fn my_func() {
  //   var v : f32;
  //   let p : ptr<function, f32> = &v;
  //   *p = 1.23;
  // }
  auto* v_use = Expr("v");
  auto* p_use = Expr("p");
  auto* v_stmt = Decl(Var("v", ty.f32()));
  auto* p_stmt = Decl(Const("p", ty.pointer<f32>(ast::StorageClass::kFunction),
                            AddressOf(v_use)));
  auto* store = Assign(Deref(p_use), 1.23f);
  Func("my_func", ast::VariableList{}, ty.void_(), {v_stmt, p_stmt, store},
       ast::DecorationList{});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  // `v` is a reference to f32, used by the `let p` declaration.
  auto* v_ty = TypeOf(v_use);
  ASSERT_NE(v_ty, nullptr);
  ASSERT_TRUE(v_ty->Is<sem::Reference>());
  EXPECT_TRUE(v_ty->UnwrapRef()->Is<sem::F32>());
  EXPECT_EQ(StmtOf(v_use), p_stmt);

  // `p` is a pointer to f32, used by the assignment.
  auto* p_ty = TypeOf(p_use);
  ASSERT_NE(p_ty, nullptr);
  ASSERT_TRUE(p_ty->Is<sem::Pointer>());
  EXPECT_TRUE(p_ty->UnwrapPtr()->Is<sem::F32>());
  EXPECT_EQ(StmtOf(p_use), store);
}
736
TEST_F(ResolverTest, Expr_Call_Function) {
  // fn my_func() -> f32 { return 0.0; }
  // ...inside a wrapper function:
  //   my_func();
  Func("my_func", ast::VariableList{}, ty.f32(), {Return(0.0f)},
       ast::DecorationList{});

  auto* invocation = Call("my_func");
  WrapInFunction(invocation);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* result_ty = TypeOf(invocation);
  ASSERT_NE(result_ty, nullptr);
  EXPECT_TRUE(result_ty->Is<sem::F32>());
}
749
TEST_F(ResolverTest, Expr_Identifier_Unknown) {
  // Referencing the undeclared identifier `a` must make resolution fail.
  WrapInFunction(Expr("a"));

  EXPECT_FALSE(r()->Resolve());
}
756
TEST_F(ResolverTest, Function_Parameters) {
  // fn my_func(a : f32, b : i32, c : u32) {}
  //
  // The semantic function must record each parameter, in declaration order,
  // with its resolved type and originating AST node.
  auto* param_a = Param("a", ty.f32());
  auto* param_b = Param("b", ty.i32());
  auto* param_c = Param("c", ty.u32());

  auto* func = Func("my_func",
                    ast::VariableList{
                        param_a,
                        param_b,
                        param_c,
                    },
                    ty.void_(), {});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* func_sem = Sem().Get(func);
  ASSERT_NE(func_sem, nullptr);
  // ASSERT (not EXPECT) so that a wrong count aborts the test before the
  // indexed accesses below can read past the end of Parameters().
  ASSERT_EQ(func_sem->Parameters().size(), 3u);
  EXPECT_TRUE(func_sem->Parameters()[0]->Type()->Is<sem::F32>());
  EXPECT_TRUE(func_sem->Parameters()[1]->Type()->Is<sem::I32>());
  EXPECT_TRUE(func_sem->Parameters()[2]->Type()->Is<sem::U32>());
  EXPECT_EQ(func_sem->Parameters()[0]->Declaration(), param_a);
  EXPECT_EQ(func_sem->Parameters()[1]->Declaration(), param_b);
  EXPECT_EQ(func_sem->Parameters()[2]->Declaration(), param_c);
  EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
}
783
TEST_F(ResolverTest, Function_RegisterInputOutputVariables) {
  // [[block]] struct S { m : u32; };
  // [[binding(0), group(0)]] var<storage, read_write> sb_var : S;
  // var<workgroup> wg_var : f32;
  // var<private> priv_var : f32;
  //
  // fn my_func() {
  //   wg_var = wg_var;
  //   sb_var = sb_var;
  //   priv_var = priv_var;
  // }
  auto* s = Structure("S", {Member("m", ty.u32())},
                      {create<ast::StructBlockDecoration>()});

  ast::DecorationList binding_decos{
      create<ast::BindingDecoration>(0),
      create<ast::GroupDecoration>(0),
  };
  auto* storage_global =
      Global("sb_var", ty.Of(s), ast::StorageClass::kStorage,
             ast::Access::kReadWrite, binding_decos);
  auto* workgroup_global =
      Global("wg_var", ty.f32(), ast::StorageClass::kWorkgroup);
  auto* private_global =
      Global("priv_var", ty.f32(), ast::StorageClass::kPrivate);

  auto* func = Func("my_func", ast::VariableList{}, ty.void_(),
                    {
                        Assign("wg_var", "wg_var"),
                        Assign("sb_var", "sb_var"),
                        Assign("priv_var", "priv_var"),
                    });

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* func_sem = Sem().Get(func);
  ASSERT_NE(func_sem, nullptr);
  EXPECT_EQ(func_sem->Parameters().size(), 0u);
  EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());

  // Globals are listed in the order the function body first references them.
  const auto& globals = func_sem->TransitivelyReferencedGlobals();
  ASSERT_EQ(globals.size(), 3u);
  EXPECT_EQ(globals[0]->Declaration(), workgroup_global);
  EXPECT_EQ(globals[1]->Declaration(), storage_global);
  EXPECT_EQ(globals[2]->Declaration(), private_global);
}
817
TEST_F(ResolverTest, Function_RegisterInputOutputVariables_SubFunction) {
  // Same globals as Function_RegisterInputOutputVariables, but they are only
  // touched inside my_func; `func` references them transitively via a call:
  //
  // fn my_func() -> f32 {
  //   wg_var = wg_var;
  //   sb_var = sb_var;
  //   priv_var = priv_var;
  //   return 0.0;
  // }
  // fn func() {
  //   my_func();
  // }
  auto* s = Structure("S", {Member("m", ty.u32())},
                      {create<ast::StructBlockDecoration>()});

  ast::DecorationList binding_decos{
      create<ast::BindingDecoration>(0),
      create<ast::GroupDecoration>(0),
  };
  auto* storage_global =
      Global("sb_var", ty.Of(s), ast::StorageClass::kStorage,
             ast::Access::kReadWrite, binding_decos);
  auto* workgroup_global =
      Global("wg_var", ty.f32(), ast::StorageClass::kWorkgroup);
  auto* private_global =
      Global("priv_var", ty.f32(), ast::StorageClass::kPrivate);

  Func("my_func", ast::VariableList{}, ty.f32(),
       {
           Assign("wg_var", "wg_var"),
           Assign("sb_var", "sb_var"),
           Assign("priv_var", "priv_var"),
           Return(0.0f),
       },
       ast::DecorationList{});

  auto* caller = Func("func", ast::VariableList{}, ty.void_(),
                      {WrapInStatement(Call("my_func"))},
                      ast::DecorationList{});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* caller_sem = Sem().Get(caller);
  ASSERT_NE(caller_sem, nullptr);
  EXPECT_EQ(caller_sem->Parameters().size(), 0u);

  const auto& globals = caller_sem->TransitivelyReferencedGlobals();
  ASSERT_EQ(globals.size(), 3u);
  EXPECT_EQ(globals[0]->Declaration(), workgroup_global);
  EXPECT_EQ(globals[1]->Declaration(), storage_global);
  EXPECT_EQ(globals[2]->Declaration(), private_global);
}
854
TEST_F(ResolverTest, Function_NotRegisterFunctionVariable) {
  // fn my_func() {
  //   var var : f32;
  //   var = 1.0;
  // }
  // A function-local `var` must not appear among referenced globals.
  auto* local_decl = Decl(Var("var", ty.f32()));
  auto* local_store = Assign("var", 1.f);
  auto* func =
      Func("my_func", ast::VariableList{}, ty.void_(), {local_decl, local_store});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* func_sem = Sem().Get(func);
  ASSERT_NE(func_sem, nullptr);

  EXPECT_EQ(func_sem->TransitivelyReferencedGlobals().size(), 0u);
  EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
}
870
TEST_F(ResolverTest, Function_NotRegisterFunctionConstant) {
  // fn my_func() {
  //   let var : f32 = f32();
  // }
  // A function-local `let` must not appear among referenced globals.
  auto* local_let = Const("var", ty.f32(), Construct(ty.f32()));
  auto* func =
      Func("my_func", ast::VariableList{}, ty.void_(), {Decl(local_let)});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* func_sem = Sem().Get(func);
  ASSERT_NE(func_sem, nullptr);

  EXPECT_EQ(func_sem->TransitivelyReferencedGlobals().size(), 0u);
  EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
}
885
TEST_F(ResolverTest, Function_NotRegisterFunctionParams) {
  // fn my_func(var : f32) {}
  //
  // A function parameter must not be reported as a referenced global.
  // The parameter is built with Param() (as in Function_Parameters) rather
  // than a `let` carrying an initializer, which is not a valid WGSL parameter.
  auto* func = Func("my_func", {Param("var", ty.f32())}, ty.void_(), {});
  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* func_sem = Sem().Get(func);
  ASSERT_NE(func_sem, nullptr);

  EXPECT_EQ(func_sem->TransitivelyReferencedGlobals().size(), 0u);
  EXPECT_TRUE(func_sem->ReturnType()->Is<sem::Void>());
}
897
TEST_F(ResolverTest, Function_CallSites) {
  // fn foo() {}
  // fn bar() {
  //   foo();
  //   foo();
  // }
  auto* callee = Func("foo", ast::VariableList{}, ty.void_(), {});

  auto* first_call = Call("foo");
  auto* second_call = Call("foo");
  auto* caller = Func("bar", ast::VariableList{}, ty.void_(),
                      {CallStmt(first_call), CallStmt(second_call)});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  // foo records both call sites, in source order.
  auto* callee_sem = Sem().Get(callee);
  ASSERT_NE(callee_sem, nullptr);
  const auto& callee_sites = callee_sem->CallSites();
  ASSERT_EQ(callee_sites.size(), 2u);
  EXPECT_EQ(callee_sites[0]->Declaration(), first_call);
  EXPECT_EQ(callee_sites[1]->Declaration(), second_call);

  // bar is never called.
  auto* caller_sem = Sem().Get(caller);
  ASSERT_NE(caller_sem, nullptr);
  EXPECT_EQ(caller_sem->CallSites().size(), 0u);
}
921
// When no workgroup_size decoration is present, the resolved workgroup size is
// 1x1x1 and no dimension is backed by an overridable constant.
TEST_F(ResolverTest, Function_WorkgroupSize_NotSet) {
  // fn main() {}
  // (no [[stage]] or [[workgroup_size]] decorations)
  auto* func = Func("main", ast::VariableList{}, ty.void_(), {}, {});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* func_sem = Sem().Get(func);
  ASSERT_NE(func_sem, nullptr);

  EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 1u);
  EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 1u);
  EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 1u);
  EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
  EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, nullptr);
  EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
}
939
// Literal workgroup_size arguments resolve to fixed values with no overridable
// constant attached to any dimension.
TEST_F(ResolverTest, Function_WorkgroupSize_Literals) {
  // [[stage(compute), workgroup_size(8, 2, 3)]]
  // fn main() {}
  auto* entry_point =
      Func("main", ast::VariableList{}, ty.void_(), {},
           {Stage(ast::PipelineStage::kCompute), WorkgroupSize(8, 2, 3)});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* entry_sem = Sem().Get(entry_point);
  ASSERT_NE(entry_sem, nullptr);

  const auto& wg = entry_sem->WorkgroupSize();
  EXPECT_EQ(wg[0].value, 8u);
  EXPECT_EQ(wg[1].value, 2u);
  EXPECT_EQ(wg[2].value, 3u);
  EXPECT_EQ(wg[0].overridable_const, nullptr);
  EXPECT_EQ(wg[1].overridable_const, nullptr);
  EXPECT_EQ(wg[2].overridable_const, nullptr);
}
959
// Non-overridable module-scope `let` constants used as workgroup_size
// arguments resolve to their values, with no overridable constant recorded.
TEST_F(ResolverTest, Function_WorkgroupSize_Consts) {
  // let width = 16;
  // let height = 8;
  // let depth = 2;
  // [[stage(compute), workgroup_size(width, height, depth)]]
  // fn main() {}
  GlobalConst("width", ty.i32(), Expr(16));
  GlobalConst("height", ty.i32(), Expr(8));
  GlobalConst("depth", ty.i32(), Expr(2));
  auto* entry_point = Func("main", ast::VariableList{}, ty.void_(), {},
                           {Stage(ast::PipelineStage::kCompute),
                            WorkgroupSize("width", "height", "depth")});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* entry_sem = Sem().Get(entry_point);
  ASSERT_NE(entry_sem, nullptr);

  const auto& wg = entry_sem->WorkgroupSize();
  EXPECT_EQ(wg[0].value, 16u);
  EXPECT_EQ(wg[1].value, 8u);
  EXPECT_EQ(wg[2].value, 2u);
  EXPECT_EQ(wg[0].overridable_const, nullptr);
  EXPECT_EQ(wg[1].overridable_const, nullptr);
  EXPECT_EQ(wg[2].overridable_const, nullptr);
}
985
// Constants whose initializers are nested i32 constructors are folded down to
// their literal values; the omitted third dimension defaults to 1.
TEST_F(ResolverTest, Function_WorkgroupSize_Consts_NestedInitializer) {
  // let width = i32(i32(i32(8)));
  // let height = i32(i32(i32(4)));
  // [[stage(compute), workgroup_size(width, height)]]
  // fn main() {}
  auto* width_init =
      Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32(), 8)));
  GlobalConst("width", ty.i32(), width_init);
  auto* height_init =
      Construct(ty.i32(), Construct(ty.i32(), Construct(ty.i32(), 4)));
  GlobalConst("height", ty.i32(), height_init);

  auto* entry_point = Func(
      "main", ast::VariableList{}, ty.void_(), {},
      {Stage(ast::PipelineStage::kCompute), WorkgroupSize("width", "height")});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* entry_sem = Sem().Get(entry_point);
  ASSERT_NE(entry_sem, nullptr);

  const auto& wg = entry_sem->WorkgroupSize();
  EXPECT_EQ(wg[0].value, 8u);
  EXPECT_EQ(wg[1].value, 4u);
  EXPECT_EQ(wg[2].value, 1u);
  EXPECT_EQ(wg[0].overridable_const, nullptr);
  EXPECT_EQ(wg[1].overridable_const, nullptr);
  EXPECT_EQ(wg[2].overridable_const, nullptr);
}
1011
// Overridable constants used as workgroup_size arguments resolve to their
// default values, and each dimension records the constant that backs it.
TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts) {
  // [[override(0)]] let width = 16;
  // [[override(1)]] let height = 8;
  // [[override(2)]] let depth = 2;
  // [[stage(compute), workgroup_size(width, height, depth)]]
  // fn main() {}
  auto* width_const = GlobalConst("width", ty.i32(), Expr(16), {Override(0)});
  auto* height_const = GlobalConst("height", ty.i32(), Expr(8), {Override(1)});
  auto* depth_const = GlobalConst("depth", ty.i32(), Expr(2), {Override(2)});
  auto* entry_point = Func("main", ast::VariableList{}, ty.void_(), {},
                           {Stage(ast::PipelineStage::kCompute),
                            WorkgroupSize("width", "height", "depth")});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* entry_sem = Sem().Get(entry_point);
  ASSERT_NE(entry_sem, nullptr);

  const auto& wg = entry_sem->WorkgroupSize();
  EXPECT_EQ(wg[0].value, 16u);
  EXPECT_EQ(wg[1].value, 8u);
  EXPECT_EQ(wg[2].value, 2u);
  EXPECT_EQ(wg[0].overridable_const, width_const);
  EXPECT_EQ(wg[1].overridable_const, height_const);
  EXPECT_EQ(wg[2].overridable_const, depth_const);
}
1037
// Overridable constants without an initializer still back their workgroup
// dimensions; the resolved value for each dimension is 0.
TEST_F(ResolverTest, Function_WorkgroupSize_OverridableConsts_NoInit) {
  // [[override(0)]] let width : i32;
  // [[override(1)]] let height : i32;
  // [[override(2)]] let depth : i32;
  // [[stage(compute), workgroup_size(width, height, depth)]]
  // fn main() {}
  auto* width_const = GlobalConst("width", ty.i32(), nullptr, {Override(0)});
  auto* height_const = GlobalConst("height", ty.i32(), nullptr, {Override(1)});
  auto* depth_const = GlobalConst("depth", ty.i32(), nullptr, {Override(2)});
  auto* entry_point = Func("main", ast::VariableList{}, ty.void_(), {},
                           {Stage(ast::PipelineStage::kCompute),
                            WorkgroupSize("width", "height", "depth")});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* entry_sem = Sem().Get(entry_point);
  ASSERT_NE(entry_sem, nullptr);

  const auto& wg = entry_sem->WorkgroupSize();
  EXPECT_EQ(wg[0].value, 0u);
  EXPECT_EQ(wg[1].value, 0u);
  EXPECT_EQ(wg[2].value, 0u);
  EXPECT_EQ(wg[0].overridable_const, width_const);
  EXPECT_EQ(wg[1].overridable_const, height_const);
  EXPECT_EQ(wg[2].overridable_const, depth_const);
}
1063
// A workgroup_size mixing a literal, an overridable constant and a plain
// constant: only the overridable dimension records a backing constant.
TEST_F(ResolverTest, Function_WorkgroupSize_Mixed) {
  // [[override(0)]] let height = 2;
  // let depth = 3;
  // [[stage(compute), workgroup_size(8, height, depth)]]
  // fn main() {}
  auto* height = GlobalConst("height", ty.i32(), Expr(2), {Override(0)});
  GlobalConst("depth", ty.i32(), Expr(3));
  auto* func = Func("main", ast::VariableList{}, ty.void_(), {},
                    {Stage(ast::PipelineStage::kCompute),
                     WorkgroupSize(8, "height", "depth")});

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* func_sem = Sem().Get(func);
  ASSERT_NE(func_sem, nullptr);

  EXPECT_EQ(func_sem->WorkgroupSize()[0].value, 8u);
  EXPECT_EQ(func_sem->WorkgroupSize()[1].value, 2u);
  EXPECT_EQ(func_sem->WorkgroupSize()[2].value, 3u);
  EXPECT_EQ(func_sem->WorkgroupSize()[0].overridable_const, nullptr);
  EXPECT_EQ(func_sem->WorkgroupSize()[1].overridable_const, height);
  EXPECT_EQ(func_sem->WorkgroupSize()[2].overridable_const, nullptr);
}
1087
// Accessing a member of a private struct global yields a reference to the
// member's type, and the semantic node identifies the accessed member.
TEST_F(ResolverTest, Expr_MemberAccessor_Struct) {
  // struct S {
  //   first_member : i32;
  //   second_member : f32;
  // };
  // var<private> my_struct : S;
  // my_struct.second_member
  auto* s = Structure("S", {Member("first_member", ty.i32()),
                            Member("second_member", ty.f32())});
  Global("my_struct", ty.Of(s), ast::StorageClass::kPrivate);

  auto* accessor = MemberAccessor("my_struct", "second_member");
  WrapInFunction(accessor);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* accessor_ty = TypeOf(accessor);
  ASSERT_NE(accessor_ty, nullptr);
  ASSERT_TRUE(accessor_ty->Is<sem::Reference>());
  EXPECT_TRUE(accessor_ty->As<sem::Reference>()->StoreType()->Is<sem::F32>());

  auto* member_access = Sem().Get(accessor)->As<sem::StructMemberAccess>();
  ASSERT_NE(member_access, nullptr);
  auto* member = member_access->Member();
  EXPECT_TRUE(member->Type()->Is<sem::F32>());
  EXPECT_EQ(member->Index(), 1u);
  EXPECT_EQ(member->Declaration()->symbol, Symbols().Get("second_member"));
}
1110
// Same as Expr_MemberAccessor_Struct, but the global is declared through a
// type alias of the struct; member access must resolve identically.
TEST_F(ResolverTest, Expr_MemberAccessor_Struct_Alias) {
  // struct S {
  //   first_member : i32;
  //   second_member : f32;
  // };
  // type alias = S;
  // var<private> my_struct : alias;
  // my_struct.second_member
  auto* st = Structure("S", {Member("first_member", ty.i32()),
                             Member("second_member", ty.f32())});
  auto* alias = Alias("alias", ty.Of(st));
  Global("my_struct", ty.Of(alias), ast::StorageClass::kPrivate);

  auto* mem = MemberAccessor("my_struct", "second_member");
  WrapInFunction(mem);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  ASSERT_NE(TypeOf(mem), nullptr);
  ASSERT_TRUE(TypeOf(mem)->Is<sem::Reference>());

  auto* ref = TypeOf(mem)->As<sem::Reference>();
  EXPECT_TRUE(ref->StoreType()->Is<sem::F32>());
  auto* sma = Sem().Get(mem)->As<sem::StructMemberAccess>();
  ASSERT_NE(sma, nullptr);
  EXPECT_TRUE(sma->Member()->Type()->Is<sem::F32>());
  EXPECT_EQ(sma->Member()->Index(), 1u);
  // Also check the declaration, as the non-alias test does.
  EXPECT_EQ(sma->Member()->Declaration()->symbol,
            Symbols().Get("second_member"));
}
1132
// A four-component swizzle of a vec4<f32> produces a vec4<f32> value and a
// Swizzle semantic node carrying the component indices.
TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle) {
  // var<private> my_vec : vec4<f32>;
  // my_vec.xzyw
  Global("my_vec", ty.vec4<f32>(), ast::StorageClass::kPrivate);

  auto* swizzle = MemberAccessor("my_vec", "xzyw");
  WrapInFunction(swizzle);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* result_ty = TypeOf(swizzle);
  ASSERT_NE(result_ty, nullptr);
  ASSERT_TRUE(result_ty->Is<sem::Vector>());
  auto* vec_ty = result_ty->As<sem::Vector>();
  EXPECT_TRUE(vec_ty->type()->Is<sem::F32>());
  EXPECT_EQ(vec_ty->Width(), 4u);

  auto* swizzle_sem = Sem().Get(swizzle);
  ASSERT_TRUE(swizzle_sem->Is<sem::Swizzle>());
  // x=0, z=2, y=1, w=3
  EXPECT_THAT(swizzle_sem->As<sem::Swizzle>()->Indices(),
              ElementsAre(0, 2, 1, 3));
}
1149
// A single-component swizzle (`.b`) of a vec3<f32> global yields a reference to
// f32 and a Swizzle semantic node with the single index 2.
TEST_F(ResolverTest, Expr_MemberAccessor_VectorSwizzle_SingleElement) {
  // var<private> my_vec : vec3<f32>;
  // my_vec.b
  Global("my_vec", ty.vec3<f32>(), ast::StorageClass::kPrivate);

  auto* component = MemberAccessor("my_vec", "b");
  WrapInFunction(component);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* result_ty = TypeOf(component);
  ASSERT_NE(result_ty, nullptr);
  ASSERT_TRUE(result_ty->Is<sem::Reference>());
  ASSERT_TRUE(result_ty->As<sem::Reference>()->StoreType()->Is<sem::F32>());

  auto* component_sem = Sem().Get(component);
  ASSERT_TRUE(component_sem->Is<sem::Swizzle>());
  EXPECT_THAT(component_sem->As<sem::Swizzle>()->Indices(), ElementsAre(2));
}
1166
// Chains a struct member access, an array index, another member access and a
// two-component swizzle, and checks the final type is vec2<f32>.
TEST_F(ResolverTest, Expr_Accessor_MultiLevel) {
  // struct B {
  //   foo : vec4<f32>;
  // };
  // struct A {
  //   mem : array<B, 3>;
  // };
  // var<private> c : A;
  //
  // fn f() {
  //   c.mem[0].foo.yx;  // -> vec2<f32>
  // }

  auto* stB = Structure("B", {Member("foo", ty.vec4<f32>())});
  auto* stA = Structure("A", {Member("mem", ty.array(ty.Of(stB), 3))});
  Global("c", ty.Of(stA), ast::StorageClass::kPrivate);

  auto* mem = MemberAccessor(
      MemberAccessor(IndexAccessor(MemberAccessor("c", "mem"), 0), "foo"),
      "yx");
  WrapInFunction(mem);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  ASSERT_NE(TypeOf(mem), nullptr);
  ASSERT_TRUE(TypeOf(mem)->Is<sem::Vector>());
  EXPECT_TRUE(TypeOf(mem)->As<sem::Vector>()->type()->Is<sem::F32>());
  EXPECT_EQ(TypeOf(mem)->As<sem::Vector>()->Width(), 2u);
  ASSERT_TRUE(Sem().Get(mem)->Is<sem::Swizzle>());
  // .yx selects components y=1 then x=0.
  EXPECT_THAT(Sem().Get(mem)->As<sem::Swizzle>()->Indices(),
              ElementsAre(1, 0));
}
1200
// Adding two f32 struct members read through member accessors resolves to an
// f32 value.
TEST_F(ResolverTest, Expr_MemberAccessor_InBinaryOp) {
  // struct S {
  //   first_member : f32;
  //   second_member : f32;
  // };
  // var<private> my_struct : S;
  // my_struct.first_member + my_struct.second_member
  auto* s = Structure("S", {Member("first_member", ty.f32()),
                            Member("second_member", ty.f32())});
  Global("my_struct", ty.Of(s), ast::StorageClass::kPrivate);

  auto* lhs = MemberAccessor("my_struct", "first_member");
  auto* rhs = MemberAccessor("my_struct", "second_member");
  auto* sum = Add(lhs, rhs);
  WrapInFunction(sum);

  EXPECT_TRUE(r()->Resolve()) << r()->error();

  auto* sum_ty = TypeOf(sum);
  ASSERT_NE(sum_ty, nullptr);
  EXPECT_TRUE(sum_ty->Is<sem::F32>());
}
1215
1216 namespace ExprBinaryTest {
1217
1218 template <typename T, int ID>
1219 struct Aliased {
1220 using type = alias<T, ID>;
1221 };
1222
1223 template <int N, typename T, int ID>
1224 struct Aliased<vec<N, T>, ID> {
1225 using type = vec<N, alias<T, ID>>;
1226 };
1227
1228 template <int N, int M, typename T, int ID>
1229 struct Aliased<mat<N, M, T>, ID> {
1230 using type = mat<N, M, alias<T, ID>>;
1231 };
1232
1233 struct Params {
1234 ast::BinaryOp op;
1235 builder::ast_type_func_ptr create_lhs_type;
1236 builder::ast_type_func_ptr create_rhs_type;
1237 builder::ast_type_func_ptr create_lhs_alias_type;
1238 builder::ast_type_func_ptr create_rhs_alias_type;
1239 builder::sem_type_func_ptr create_result_type;
1240 };
1241
1242 template <typename LHS, typename RHS, typename RES>
ParamsFor(ast::BinaryOp op)1243 constexpr Params ParamsFor(ast::BinaryOp op) {
1244 return Params{op,
1245 DataType<LHS>::AST,
1246 DataType<RHS>::AST,
1247 DataType<typename Aliased<LHS, 0>::type>::AST,
1248 DataType<typename Aliased<RHS, 1>::type>::AST,
1249 DataType<RES>::Sem};
1250 }
1251
1252 static constexpr ast::BinaryOp all_ops[] = {
1253 ast::BinaryOp::kAnd,
1254 ast::BinaryOp::kOr,
1255 ast::BinaryOp::kXor,
1256 ast::BinaryOp::kLogicalAnd,
1257 ast::BinaryOp::kLogicalOr,
1258 ast::BinaryOp::kEqual,
1259 ast::BinaryOp::kNotEqual,
1260 ast::BinaryOp::kLessThan,
1261 ast::BinaryOp::kGreaterThan,
1262 ast::BinaryOp::kLessThanEqual,
1263 ast::BinaryOp::kGreaterThanEqual,
1264 ast::BinaryOp::kShiftLeft,
1265 ast::BinaryOp::kShiftRight,
1266 ast::BinaryOp::kAdd,
1267 ast::BinaryOp::kSubtract,
1268 ast::BinaryOp::kMultiply,
1269 ast::BinaryOp::kDivide,
1270 ast::BinaryOp::kModulo,
1271 };
1272
1273 static constexpr builder::ast_type_func_ptr all_create_type_funcs[] = {
1274 DataType<bool>::AST, //
1275 DataType<u32>::AST, //
1276 DataType<i32>::AST, //
1277 DataType<f32>::AST, //
1278 DataType<vec3<bool>>::AST, //
1279 DataType<vec3<i32>>::AST, //
1280 DataType<vec3<u32>>::AST, //
1281 DataType<vec3<f32>>::AST, //
1282 DataType<mat3x3<f32>>::AST, //
1283 DataType<mat2x3<f32>>::AST, //
1284 DataType<mat3x2<f32>>::AST //
1285 };
1286
1287 // A list of all valid test cases for 'lhs op rhs', except that for vecN and
1288 // matNxN, we only test N=3.
1289 static constexpr Params all_valid_cases[] = {
1290 // Logical expressions
1291 // https://gpuweb.github.io/gpuweb/wgsl.html#logical-expr
1292
1293 // Binary logical expressions
1294 ParamsFor<bool, bool, bool>(Op::kLogicalAnd),
1295 ParamsFor<bool, bool, bool>(Op::kLogicalOr),
1296
1297 ParamsFor<bool, bool, bool>(Op::kAnd),
1298 ParamsFor<bool, bool, bool>(Op::kOr),
1299 ParamsFor<vec3<bool>, vec3<bool>, vec3<bool>>(Op::kAnd),
1300 ParamsFor<vec3<bool>, vec3<bool>, vec3<bool>>(Op::kOr),
1301
1302 // Arithmetic expressions
1303 // https://gpuweb.github.io/gpuweb/wgsl.html#arithmetic-expr
1304
1305 // Binary arithmetic expressions over scalars
1306 ParamsFor<i32, i32, i32>(Op::kAdd),
1307 ParamsFor<i32, i32, i32>(Op::kSubtract),
1308 ParamsFor<i32, i32, i32>(Op::kMultiply),
1309 ParamsFor<i32, i32, i32>(Op::kDivide),
1310 ParamsFor<i32, i32, i32>(Op::kModulo),
1311
1312 ParamsFor<u32, u32, u32>(Op::kAdd),
1313 ParamsFor<u32, u32, u32>(Op::kSubtract),
1314 ParamsFor<u32, u32, u32>(Op::kMultiply),
1315 ParamsFor<u32, u32, u32>(Op::kDivide),
1316 ParamsFor<u32, u32, u32>(Op::kModulo),
1317
1318 ParamsFor<f32, f32, f32>(Op::kAdd),
1319 ParamsFor<f32, f32, f32>(Op::kSubtract),
1320 ParamsFor<f32, f32, f32>(Op::kMultiply),
1321 ParamsFor<f32, f32, f32>(Op::kDivide),
1322 ParamsFor<f32, f32, f32>(Op::kModulo),
1323
1324 // Binary arithmetic expressions over vectors
1325 ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kAdd),
1326 ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kSubtract),
1327 ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kMultiply),
1328 ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kDivide),
1329 ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kModulo),
1330
1331 ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kAdd),
1332 ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kSubtract),
1333 ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kMultiply),
1334 ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kDivide),
1335 ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kModulo),
1336
1337 ParamsFor<vec3<f32>, vec3<f32>, vec3<f32>>(Op::kAdd),
1338 ParamsFor<vec3<f32>, vec3<f32>, vec3<f32>>(Op::kSubtract),
1339 ParamsFor<vec3<f32>, vec3<f32>, vec3<f32>>(Op::kMultiply),
1340 ParamsFor<vec3<f32>, vec3<f32>, vec3<f32>>(Op::kDivide),
1341 ParamsFor<vec3<f32>, vec3<f32>, vec3<f32>>(Op::kModulo),
1342
1343 // Binary arithmetic expressions with mixed scalar and vector operands
1344 ParamsFor<vec3<i32>, i32, vec3<i32>>(Op::kAdd),
1345 ParamsFor<vec3<i32>, i32, vec3<i32>>(Op::kSubtract),
1346 ParamsFor<vec3<i32>, i32, vec3<i32>>(Op::kMultiply),
1347 ParamsFor<vec3<i32>, i32, vec3<i32>>(Op::kDivide),
1348 ParamsFor<vec3<i32>, i32, vec3<i32>>(Op::kModulo),
1349
1350 ParamsFor<i32, vec3<i32>, vec3<i32>>(Op::kAdd),
1351 ParamsFor<i32, vec3<i32>, vec3<i32>>(Op::kSubtract),
1352 ParamsFor<i32, vec3<i32>, vec3<i32>>(Op::kMultiply),
1353 ParamsFor<i32, vec3<i32>, vec3<i32>>(Op::kDivide),
1354 ParamsFor<i32, vec3<i32>, vec3<i32>>(Op::kModulo),
1355
1356 ParamsFor<vec3<u32>, u32, vec3<u32>>(Op::kAdd),
1357 ParamsFor<vec3<u32>, u32, vec3<u32>>(Op::kSubtract),
1358 ParamsFor<vec3<u32>, u32, vec3<u32>>(Op::kMultiply),
1359 ParamsFor<vec3<u32>, u32, vec3<u32>>(Op::kDivide),
1360 ParamsFor<vec3<u32>, u32, vec3<u32>>(Op::kModulo),
1361
1362 ParamsFor<u32, vec3<u32>, vec3<u32>>(Op::kAdd),
1363 ParamsFor<u32, vec3<u32>, vec3<u32>>(Op::kSubtract),
1364 ParamsFor<u32, vec3<u32>, vec3<u32>>(Op::kMultiply),
1365 ParamsFor<u32, vec3<u32>, vec3<u32>>(Op::kDivide),
1366 ParamsFor<u32, vec3<u32>, vec3<u32>>(Op::kModulo),
1367
1368 ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kAdd),
1369 ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kSubtract),
1370 ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kMultiply),
1371 ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kDivide),
1372 // NOTE: no kModulo for vec3<f32>, f32
1373 // ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kModulo),
1374
1375 ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kAdd),
1376 ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kSubtract),
1377 ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kMultiply),
1378 ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kDivide),
1379 // NOTE: no kModulo for f32, vec3<f32>
1380 // ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kModulo),
1381
1382 // Matrix arithmetic
1383 ParamsFor<mat2x3<f32>, f32, mat2x3<f32>>(Op::kMultiply),
1384 ParamsFor<mat3x2<f32>, f32, mat3x2<f32>>(Op::kMultiply),
1385 ParamsFor<mat3x3<f32>, f32, mat3x3<f32>>(Op::kMultiply),
1386
1387 ParamsFor<f32, mat2x3<f32>, mat2x3<f32>>(Op::kMultiply),
1388 ParamsFor<f32, mat3x2<f32>, mat3x2<f32>>(Op::kMultiply),
1389 ParamsFor<f32, mat3x3<f32>, mat3x3<f32>>(Op::kMultiply),
1390
1391 ParamsFor<vec3<f32>, mat2x3<f32>, vec2<f32>>(Op::kMultiply),
1392 ParamsFor<vec2<f32>, mat3x2<f32>, vec3<f32>>(Op::kMultiply),
1393 ParamsFor<vec3<f32>, mat3x3<f32>, vec3<f32>>(Op::kMultiply),
1394
1395 ParamsFor<mat3x2<f32>, vec3<f32>, vec2<f32>>(Op::kMultiply),
1396 ParamsFor<mat2x3<f32>, vec2<f32>, vec3<f32>>(Op::kMultiply),
1397 ParamsFor<mat3x3<f32>, vec3<f32>, vec3<f32>>(Op::kMultiply),
1398
1399 ParamsFor<mat2x3<f32>, mat3x2<f32>, mat3x3<f32>>(Op::kMultiply),
1400 ParamsFor<mat3x2<f32>, mat2x3<f32>, mat2x2<f32>>(Op::kMultiply),
1401 ParamsFor<mat3x2<f32>, mat3x3<f32>, mat3x2<f32>>(Op::kMultiply),
1402 ParamsFor<mat3x3<f32>, mat3x3<f32>, mat3x3<f32>>(Op::kMultiply),
1403 ParamsFor<mat3x3<f32>, mat2x3<f32>, mat2x3<f32>>(Op::kMultiply),
1404
1405 ParamsFor<mat2x3<f32>, mat2x3<f32>, mat2x3<f32>>(Op::kAdd),
1406 ParamsFor<mat3x2<f32>, mat3x2<f32>, mat3x2<f32>>(Op::kAdd),
1407 ParamsFor<mat3x3<f32>, mat3x3<f32>, mat3x3<f32>>(Op::kAdd),
1408
1409 ParamsFor<mat2x3<f32>, mat2x3<f32>, mat2x3<f32>>(Op::kSubtract),
1410 ParamsFor<mat3x2<f32>, mat3x2<f32>, mat3x2<f32>>(Op::kSubtract),
1411 ParamsFor<mat3x3<f32>, mat3x3<f32>, mat3x3<f32>>(Op::kSubtract),
1412
1413 // Comparison expressions
1414 // https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr
1415
1416 // Comparisons over scalars
1417 ParamsFor<bool, bool, bool>(Op::kEqual),
1418 ParamsFor<bool, bool, bool>(Op::kNotEqual),
1419
1420 ParamsFor<i32, i32, bool>(Op::kEqual),
1421 ParamsFor<i32, i32, bool>(Op::kNotEqual),
1422 ParamsFor<i32, i32, bool>(Op::kLessThan),
1423 ParamsFor<i32, i32, bool>(Op::kLessThanEqual),
1424 ParamsFor<i32, i32, bool>(Op::kGreaterThan),
1425 ParamsFor<i32, i32, bool>(Op::kGreaterThanEqual),
1426
1427 ParamsFor<u32, u32, bool>(Op::kEqual),
1428 ParamsFor<u32, u32, bool>(Op::kNotEqual),
1429 ParamsFor<u32, u32, bool>(Op::kLessThan),
1430 ParamsFor<u32, u32, bool>(Op::kLessThanEqual),
1431 ParamsFor<u32, u32, bool>(Op::kGreaterThan),
1432 ParamsFor<u32, u32, bool>(Op::kGreaterThanEqual),
1433
1434 ParamsFor<f32, f32, bool>(Op::kEqual),
1435 ParamsFor<f32, f32, bool>(Op::kNotEqual),
1436 ParamsFor<f32, f32, bool>(Op::kLessThan),
1437 ParamsFor<f32, f32, bool>(Op::kLessThanEqual),
1438 ParamsFor<f32, f32, bool>(Op::kGreaterThan),
1439 ParamsFor<f32, f32, bool>(Op::kGreaterThanEqual),
1440
1441 // Comparisons over vectors
1442 ParamsFor<vec3<bool>, vec3<bool>, vec3<bool>>(Op::kEqual),
1443 ParamsFor<vec3<bool>, vec3<bool>, vec3<bool>>(Op::kNotEqual),
1444
1445 ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kEqual),
1446 ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kNotEqual),
1447 ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kLessThan),
1448 ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kLessThanEqual),
1449 ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kGreaterThan),
1450 ParamsFor<vec3<i32>, vec3<i32>, vec3<bool>>(Op::kGreaterThanEqual),
1451
1452 ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kEqual),
1453 ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kNotEqual),
1454 ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kLessThan),
1455 ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kLessThanEqual),
1456 ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kGreaterThan),
1457 ParamsFor<vec3<u32>, vec3<u32>, vec3<bool>>(Op::kGreaterThanEqual),
1458
1459 ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kEqual),
1460 ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kNotEqual),
1461 ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kLessThan),
1462 ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kLessThanEqual),
1463 ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kGreaterThan),
1464 ParamsFor<vec3<f32>, vec3<f32>, vec3<bool>>(Op::kGreaterThanEqual),
1465
1466 // Binary bitwise operations
1467 ParamsFor<i32, i32, i32>(Op::kOr),
1468 ParamsFor<i32, i32, i32>(Op::kAnd),
1469 ParamsFor<i32, i32, i32>(Op::kXor),
1470
1471 ParamsFor<u32, u32, u32>(Op::kOr),
1472 ParamsFor<u32, u32, u32>(Op::kAnd),
1473 ParamsFor<u32, u32, u32>(Op::kXor),
1474
1475 ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kOr),
1476 ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kAnd),
1477 ParamsFor<vec3<i32>, vec3<i32>, vec3<i32>>(Op::kXor),
1478
1479 ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kOr),
1480 ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kAnd),
1481 ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kXor),
1482
1483 // Bit shift expressions
1484 ParamsFor<i32, u32, i32>(Op::kShiftLeft),
1485 ParamsFor<vec3<i32>, vec3<u32>, vec3<i32>>(Op::kShiftLeft),
1486
1487 ParamsFor<u32, u32, u32>(Op::kShiftLeft),
1488 ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kShiftLeft),
1489
1490 ParamsFor<i32, u32, i32>(Op::kShiftRight),
1491 ParamsFor<vec3<i32>, vec3<u32>, vec3<i32>>(Op::kShiftRight),
1492
1493 ParamsFor<u32, u32, u32>(Op::kShiftRight),
1494 ParamsFor<vec3<u32>, vec3<u32>, vec3<u32>>(Op::kShiftRight),
1495 };
1496
1497 using Expr_Binary_Test_Valid = ResolverTestWithParam<Params>;
TEST_P(Expr_Binary_Test_Valid,All)1498 TEST_P(Expr_Binary_Test_Valid, All) {
1499 auto& params = GetParam();
1500
1501 auto* lhs_type = params.create_lhs_type(*this);
1502 auto* rhs_type = params.create_rhs_type(*this);
1503 auto* result_type = params.create_result_type(*this);
1504
1505 std::stringstream ss;
1506 ss << FriendlyName(lhs_type) << " " << params.op << " "
1507 << FriendlyName(rhs_type);
1508 SCOPED_TRACE(ss.str());
1509
1510 Global("lhs", lhs_type, ast::StorageClass::kPrivate);
1511 Global("rhs", rhs_type, ast::StorageClass::kPrivate);
1512
1513 auto* expr =
1514 create<ast::BinaryExpression>(params.op, Expr("lhs"), Expr("rhs"));
1515 WrapInFunction(expr);
1516
1517 ASSERT_TRUE(r()->Resolve()) << r()->error();
1518 ASSERT_NE(TypeOf(expr), nullptr);
1519 ASSERT_TRUE(TypeOf(expr) == result_type);
1520 }
1521 INSTANTIATE_TEST_SUITE_P(ResolverTest,
1522 Expr_Binary_Test_Valid,
1523 testing::ValuesIn(all_valid_cases));
1524
1525 enum class BinaryExprSide { Left, Right, Both };
1526 using Expr_Binary_Test_WithAlias_Valid =
1527 ResolverTestWithParam<std::tuple<Params, BinaryExprSide>>;
TEST_P(Expr_Binary_Test_WithAlias_Valid,All)1528 TEST_P(Expr_Binary_Test_WithAlias_Valid, All) {
1529 const Params& params = std::get<0>(GetParam());
1530 BinaryExprSide side = std::get<1>(GetParam());
1531
1532 auto* create_lhs_type =
1533 (side == BinaryExprSide::Left || side == BinaryExprSide::Both)
1534 ? params.create_lhs_alias_type
1535 : params.create_lhs_type;
1536 auto* create_rhs_type =
1537 (side == BinaryExprSide::Right || side == BinaryExprSide::Both)
1538 ? params.create_rhs_alias_type
1539 : params.create_rhs_type;
1540
1541 auto* lhs_type = create_lhs_type(*this);
1542 auto* rhs_type = create_rhs_type(*this);
1543
1544 std::stringstream ss;
1545 ss << FriendlyName(lhs_type) << " " << params.op << " "
1546 << FriendlyName(rhs_type);
1547
1548 ss << ", After aliasing: " << FriendlyName(lhs_type) << " " << params.op
1549 << " " << FriendlyName(rhs_type);
1550 SCOPED_TRACE(ss.str());
1551
1552 Global("lhs", lhs_type, ast::StorageClass::kPrivate);
1553 Global("rhs", rhs_type, ast::StorageClass::kPrivate);
1554
1555 auto* expr =
1556 create<ast::BinaryExpression>(params.op, Expr("lhs"), Expr("rhs"));
1557 WrapInFunction(expr);
1558
1559 ASSERT_TRUE(r()->Resolve()) << r()->error();
1560 ASSERT_NE(TypeOf(expr), nullptr);
1561 // TODO(amaiorano): Bring this back once we have a way to get the canonical
1562 // type
1563 // auto* *result_type = params.create_result_type(*this);
1564 // ASSERT_TRUE(TypeOf(expr) == result_type);
1565 }
1566 INSTANTIATE_TEST_SUITE_P(
1567 ResolverTest,
1568 Expr_Binary_Test_WithAlias_Valid,
1569 testing::Combine(testing::ValuesIn(all_valid_cases),
1570 testing::Values(BinaryExprSide::Left,
1571 BinaryExprSide::Right,
1572 BinaryExprSide::Both)));
1573
1574 // This test works by taking the cartesian product of all possible
1575 // (type * type * op), and processing only the triplets that are not found in
1576 // the `all_valid_cases` table.
1577 using Expr_Binary_Test_Invalid =
1578 ResolverTestWithParam<std::tuple<builder::ast_type_func_ptr,
1579 builder::ast_type_func_ptr,
1580 ast::BinaryOp>>;
TEST_P(Expr_Binary_Test_Invalid,All)1581 TEST_P(Expr_Binary_Test_Invalid, All) {
1582 const builder::ast_type_func_ptr& lhs_create_type_func =
1583 std::get<0>(GetParam());
1584 const builder::ast_type_func_ptr& rhs_create_type_func =
1585 std::get<1>(GetParam());
1586 const ast::BinaryOp op = std::get<2>(GetParam());
1587
1588 // Skip if valid case
1589 // TODO(amaiorano): replace linear lookup with O(1) if too slow
1590 for (auto& c : all_valid_cases) {
1591 if (c.create_lhs_type == lhs_create_type_func &&
1592 c.create_rhs_type == rhs_create_type_func && c.op == op) {
1593 return;
1594 }
1595 }
1596
1597 auto* lhs_type = lhs_create_type_func(*this);
1598 auto* rhs_type = rhs_create_type_func(*this);
1599
1600 std::stringstream ss;
1601 ss << FriendlyName(lhs_type) << " " << op << " " << FriendlyName(rhs_type);
1602 SCOPED_TRACE(ss.str());
1603
1604 Global("lhs", lhs_type, ast::StorageClass::kPrivate);
1605 Global("rhs", rhs_type, ast::StorageClass::kPrivate);
1606
1607 auto* expr = create<ast::BinaryExpression>(Source{{12, 34}}, op, Expr("lhs"),
1608 Expr("rhs"));
1609 WrapInFunction(expr);
1610
1611 ASSERT_FALSE(r()->Resolve());
1612 ASSERT_EQ(r()->error(),
1613 "12:34 error: Binary expression operand types are invalid for "
1614 "this operation: " +
1615 FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op) +
1616 " " + FriendlyName(rhs_type));
1617 }
1618 INSTANTIATE_TEST_SUITE_P(
1619 ResolverTest,
1620 Expr_Binary_Test_Invalid,
1621 testing::Combine(testing::ValuesIn(all_create_type_funcs),
1622 testing::ValuesIn(all_create_type_funcs),
1623 testing::ValuesIn(all_ops)));
1624
1625 using Expr_Binary_Test_Invalid_VectorMatrixMultiply =
1626 ResolverTestWithParam<std::tuple<bool, uint32_t, uint32_t, uint32_t>>;
TEST_P(Expr_Binary_Test_Invalid_VectorMatrixMultiply,All)1627 TEST_P(Expr_Binary_Test_Invalid_VectorMatrixMultiply, All) {
1628 bool vec_by_mat = std::get<0>(GetParam());
1629 uint32_t vec_size = std::get<1>(GetParam());
1630 uint32_t mat_rows = std::get<2>(GetParam());
1631 uint32_t mat_cols = std::get<3>(GetParam());
1632
1633 const ast::Type* lhs_type = nullptr;
1634 const ast::Type* rhs_type = nullptr;
1635 const sem::Type* result_type = nullptr;
1636 bool is_valid_expr;
1637
1638 if (vec_by_mat) {
1639 lhs_type = ty.vec<f32>(vec_size);
1640 rhs_type = ty.mat<f32>(mat_cols, mat_rows);
1641 result_type = create<sem::Vector>(create<sem::F32>(), mat_cols);
1642 is_valid_expr = vec_size == mat_rows;
1643 } else {
1644 lhs_type = ty.mat<f32>(mat_cols, mat_rows);
1645 rhs_type = ty.vec<f32>(vec_size);
1646 result_type = create<sem::Vector>(create<sem::F32>(), mat_rows);
1647 is_valid_expr = vec_size == mat_cols;
1648 }
1649
1650 Global("lhs", lhs_type, ast::StorageClass::kPrivate);
1651 Global("rhs", rhs_type, ast::StorageClass::kPrivate);
1652
1653 auto* expr = Mul(Source{{12, 34}}, Expr("lhs"), Expr("rhs"));
1654 WrapInFunction(expr);
1655
1656 if (is_valid_expr) {
1657 ASSERT_TRUE(r()->Resolve()) << r()->error();
1658 ASSERT_TRUE(TypeOf(expr) == result_type);
1659 } else {
1660 ASSERT_FALSE(r()->Resolve());
1661 ASSERT_EQ(r()->error(),
1662 "12:34 error: Binary expression operand types are invalid for "
1663 "this operation: " +
1664 FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op) +
1665 " " + FriendlyName(rhs_type));
1666 }
1667 }
1668 auto all_dimension_values = testing::Values(2u, 3u, 4u);
1669 INSTANTIATE_TEST_SUITE_P(ResolverTest,
1670 Expr_Binary_Test_Invalid_VectorMatrixMultiply,
1671 testing::Combine(testing::Values(true, false),
1672 all_dimension_values,
1673 all_dimension_values,
1674 all_dimension_values));
1675
1676 using Expr_Binary_Test_Invalid_MatrixMatrixMultiply =
1677 ResolverTestWithParam<std::tuple<uint32_t, uint32_t, uint32_t, uint32_t>>;
TEST_P(Expr_Binary_Test_Invalid_MatrixMatrixMultiply,All)1678 TEST_P(Expr_Binary_Test_Invalid_MatrixMatrixMultiply, All) {
1679 uint32_t lhs_mat_rows = std::get<0>(GetParam());
1680 uint32_t lhs_mat_cols = std::get<1>(GetParam());
1681 uint32_t rhs_mat_rows = std::get<2>(GetParam());
1682 uint32_t rhs_mat_cols = std::get<3>(GetParam());
1683
1684 auto* lhs_type = ty.mat<f32>(lhs_mat_cols, lhs_mat_rows);
1685 auto* rhs_type = ty.mat<f32>(rhs_mat_cols, rhs_mat_rows);
1686
1687 auto* f32 = create<sem::F32>();
1688 auto* col = create<sem::Vector>(f32, lhs_mat_rows);
1689 auto* result_type = create<sem::Matrix>(col, rhs_mat_cols);
1690
1691 Global("lhs", lhs_type, ast::StorageClass::kPrivate);
1692 Global("rhs", rhs_type, ast::StorageClass::kPrivate);
1693
1694 auto* expr = Mul(Source{{12, 34}}, Expr("lhs"), Expr("rhs"));
1695 WrapInFunction(expr);
1696
1697 bool is_valid_expr = lhs_mat_cols == rhs_mat_rows;
1698 if (is_valid_expr) {
1699 ASSERT_TRUE(r()->Resolve()) << r()->error();
1700 ASSERT_TRUE(TypeOf(expr) == result_type);
1701 } else {
1702 ASSERT_FALSE(r()->Resolve());
1703 ASSERT_EQ(r()->error(),
1704 "12:34 error: Binary expression operand types are invalid for "
1705 "this operation: " +
1706 FriendlyName(lhs_type) + " " + ast::FriendlyName(expr->op) +
1707 " " + FriendlyName(rhs_type));
1708 }
1709 }
1710 INSTANTIATE_TEST_SUITE_P(ResolverTest,
1711 Expr_Binary_Test_Invalid_MatrixMatrixMultiply,
1712 testing::Combine(all_dimension_values,
1713 all_dimension_values,
1714 all_dimension_values,
1715 all_dimension_values));
1716
1717 } // namespace ExprBinaryTest
1718
1719 using UnaryOpExpressionTest = ResolverTestWithParam<ast::UnaryOp>;
TEST_P(UnaryOpExpressionTest,Expr_UnaryOp)1720 TEST_P(UnaryOpExpressionTest, Expr_UnaryOp) {
1721 auto op = GetParam();
1722
1723 if (op == ast::UnaryOp::kNot) {
1724 Global("ident", ty.vec4<bool>(), ast::StorageClass::kPrivate);
1725 } else if (op == ast::UnaryOp::kNegation || op == ast::UnaryOp::kComplement) {
1726 Global("ident", ty.vec4<i32>(), ast::StorageClass::kPrivate);
1727 } else {
1728 Global("ident", ty.vec4<f32>(), ast::StorageClass::kPrivate);
1729 }
1730 auto* der = create<ast::UnaryOpExpression>(op, Expr("ident"));
1731 WrapInFunction(der);
1732
1733 EXPECT_TRUE(r()->Resolve()) << r()->error();
1734
1735 ASSERT_NE(TypeOf(der), nullptr);
1736 ASSERT_TRUE(TypeOf(der)->Is<sem::Vector>());
1737 if (op == ast::UnaryOp::kNot) {
1738 EXPECT_TRUE(TypeOf(der)->As<sem::Vector>()->type()->Is<sem::Bool>());
1739 } else if (op == ast::UnaryOp::kNegation || op == ast::UnaryOp::kComplement) {
1740 EXPECT_TRUE(TypeOf(der)->As<sem::Vector>()->type()->Is<sem::I32>());
1741 } else {
1742 EXPECT_TRUE(TypeOf(der)->As<sem::Vector>()->type()->Is<sem::F32>());
1743 }
1744 EXPECT_EQ(TypeOf(der)->As<sem::Vector>()->Width(), 4u);
1745 }
1746 INSTANTIATE_TEST_SUITE_P(ResolverTest,
1747 UnaryOpExpressionTest,
1748 testing::Values(ast::UnaryOp::kComplement,
1749 ast::UnaryOp::kNegation,
1750 ast::UnaryOp::kNot));
1751
// A function-scope `var` declared without an explicit storage class must be
// resolved with StorageClass::kFunction.
TEST_F(ResolverTest, StorageClass_SetsIfMissing) {
  auto* var = Var("var", ty.i32());

  auto* stmt = Decl(var);
  Func("func", ast::VariableList{}, ty.void_(), {stmt}, ast::DecorationList{});

  // ASSERT (not EXPECT): if resolution fails the semantic node may be missing,
  // and dereferencing it below would crash the test binary instead of failing.
  ASSERT_TRUE(r()->Resolve()) << r()->error();

  auto* sem_var = Sem().Get(var);
  ASSERT_NE(sem_var, nullptr);
  EXPECT_EQ(sem_var->StorageClass(), ast::StorageClass::kFunction);
}
1762
// A module-scope sampler declared without a storage class must be resolved
// with StorageClass::kUniformConstant.
TEST_F(ResolverTest, StorageClass_SetForSampler) {
  auto* t = ty.sampler(ast::SamplerKind::kSampler);
  auto* var = Global("var", t,
                     ast::DecorationList{
                         create<ast::BindingDecoration>(0),
                         create<ast::GroupDecoration>(0),
                     });

  // ASSERT so a resolve failure cannot lead to a null dereference below.
  ASSERT_TRUE(r()->Resolve()) << r()->error();

  auto* sem_var = Sem().Get(var);
  ASSERT_NE(sem_var, nullptr);
  EXPECT_EQ(sem_var->StorageClass(), ast::StorageClass::kUniformConstant);
}
1776
// A module-scope sampled texture declared without a storage class must be
// resolved with StorageClass::kUniformConstant.
TEST_F(ResolverTest, StorageClass_SetForTexture) {
  auto* t = ty.sampled_texture(ast::TextureDimension::k1d, ty.f32());
  auto* var = Global("var", t,
                     ast::DecorationList{
                         create<ast::BindingDecoration>(0),
                         create<ast::GroupDecoration>(0),
                     });

  // ASSERT so a resolve failure cannot lead to a null dereference below.
  ASSERT_TRUE(r()->Resolve()) << r()->error();

  auto* sem_var = Sem().Get(var);
  ASSERT_NE(sem_var, nullptr);
  EXPECT_EQ(sem_var->StorageClass(), ast::StorageClass::kUniformConstant);
}
1790
// A function-scope constant (built with Const) must keep
// StorageClass::kNone; the resolver only fills in storage classes for vars.
TEST_F(ResolverTest, StorageClass_DoesNotSetOnConst) {
  auto* var = Const("var", ty.i32(), Construct(ty.i32()));
  auto* stmt = Decl(var);
  Func("func", ast::VariableList{}, ty.void_(), {stmt}, ast::DecorationList{});

  // ASSERT so a resolve failure cannot lead to a null dereference below.
  ASSERT_TRUE(r()->Resolve()) << r()->error();

  auto* sem_var = Sem().Get(var);
  ASSERT_NE(sem_var, nullptr);
  EXPECT_EQ(sem_var->StorageClass(), ast::StorageClass::kNone);
}
1800
// A storage-buffer variable declared without an explicit access mode must be
// resolved with ast::Access::kRead.
TEST_F(ResolverTest, Access_SetForStorageBuffer) {
  // [[block]] struct S { x : i32 };
  // var<storage> g : S;
  auto* s = Structure("S", {Member(Source{{12, 34}}, "x", ty.i32())},
                      {create<ast::StructBlockDecoration>()});
  auto* var =
      Global(Source{{56, 78}}, "g", ty.Of(s), ast::StorageClass::kStorage,
             ast::DecorationList{
                 create<ast::BindingDecoration>(0),
                 create<ast::GroupDecoration>(0),
             });

  // ASSERT so a resolve failure cannot lead to a null dereference below.
  ASSERT_TRUE(r()->Resolve()) << r()->error();

  auto* sem_var = Sem().Get(var);
  ASSERT_NE(sem_var, nullptr);
  EXPECT_EQ(sem_var->Access(), ast::Access::kRead);
}
1817
// The group/binding decorations on resource globals must be reflected in the
// resolved sem::GlobalVariable::BindingPoint().
TEST_F(ResolverTest, BindingPoint_SetForResources) {
  // [[group(1), binding(2)]] var s1 : sampler;
  // [[group(3), binding(4)]] var s2 : sampler;
  auto* s1 = Global(Sym(), ty.sampler(ast::SamplerKind::kSampler),
                    ast::DecorationList{create<ast::GroupDecoration>(1),
                                        create<ast::BindingDecoration>(2)});
  auto* s2 = Global(Sym(), ty.sampler(ast::SamplerKind::kSampler),
                    ast::DecorationList{create<ast::GroupDecoration>(3),
                                        create<ast::BindingDecoration>(4)});

  // ASSERT so a resolve failure cannot lead to a null dereference below.
  ASSERT_TRUE(r()->Resolve()) << r()->error();

  auto* s1_sem = Sem().Get<sem::GlobalVariable>(s1);
  auto* s2_sem = Sem().Get<sem::GlobalVariable>(s2);
  ASSERT_NE(s1_sem, nullptr);
  ASSERT_NE(s2_sem, nullptr);
  EXPECT_EQ(s1_sem->BindingPoint(), (sem::BindingPoint{1u, 2u}));
  EXPECT_EQ(s2_sem->BindingPoint(), (sem::BindingPoint{3u, 4u}));
}
1835
// Checks that each function records the entry points it is (transitively)
// reachable from, and that entry points themselves record none.
TEST_F(ResolverTest, Function_EntryPoints_StageDecoration) {
  // var<private> first : f32;
  // var<private> second : f32;
  // var<private> call_a : f32;
  // var<private> call_b : f32;
  // var<private> call_c : f32;
  //
  // fn b() -> f32 { return 0.0; }
  // fn c() -> f32 { second = b(); return 0.0; }
  // fn a() -> f32 { first = c(); return 0.0; }
  // [[stage(compute), workgroup_size(1)]]
  // fn ep_1() { call_a = a(); call_b = b(); }
  // [[stage(compute), workgroup_size(1)]]
  // fn ep_2() { call_c = c(); }
  //
  // Expected ancestor entry points:
  //   a    -> {ep_1}
  //   b    -> {ep_1, ep_2}
  //   c    -> {ep_1, ep_2}
  //   ep_1 -> {}
  //   ep_2 -> {}

  Global("first", ty.f32(), ast::StorageClass::kPrivate);
  Global("second", ty.f32(), ast::StorageClass::kPrivate);
  Global("call_a", ty.f32(), ast::StorageClass::kPrivate);
  Global("call_b", ty.f32(), ast::StorageClass::kPrivate);
  Global("call_c", ty.f32(), ast::StorageClass::kPrivate);

  ast::VariableList no_params;
  auto* fn_b =
      Func("b", no_params, ty.f32(), {Return(0.0f)}, ast::DecorationList{});
  auto* fn_c = Func("c", no_params, ty.f32(),
                    {Assign("second", Call("b")), Return(0.0f)},
                    ast::DecorationList{});
  auto* fn_a = Func("a", no_params, ty.f32(),
                    {Assign("first", Call("c")), Return(0.0f)},
                    ast::DecorationList{});

  // Each entry point gets its own freshly built decoration nodes.
  auto* entry_1 = Func("ep_1", no_params, ty.void_(),
                       {
                           Assign("call_a", Call("a")),
                           Assign("call_b", Call("b")),
                       },
                       ast::DecorationList{Stage(ast::PipelineStage::kCompute),
                                           WorkgroupSize(1)});
  auto* entry_2 = Func("ep_2", no_params, ty.void_(),
                       {
                           Assign("call_c", Call("c")),
                       },
                       ast::DecorationList{Stage(ast::PipelineStage::kCompute),
                                           WorkgroupSize(1)});

  ASSERT_TRUE(r()->Resolve()) << r()->error();

  auto* sem_b = Sem().Get(fn_b);
  auto* sem_a = Sem().Get(fn_a);
  auto* sem_c = Sem().Get(fn_c);
  auto* sem_ep_1 = Sem().Get(entry_1);
  auto* sem_ep_2 = Sem().Get(entry_2);
  ASSERT_NE(sem_b, nullptr);
  ASSERT_NE(sem_a, nullptr);
  ASSERT_NE(sem_c, nullptr);
  ASSERT_NE(sem_ep_1, nullptr);
  ASSERT_NE(sem_ep_2, nullptr);

  EXPECT_EQ(sem_b->Parameters().size(), 0u);
  EXPECT_EQ(sem_a->Parameters().size(), 0u);
  EXPECT_EQ(sem_c->Parameters().size(), 0u);

  // Both names are already registered by Func(), so these are pure lookups.
  // Ancestors are expected in entry-point declaration order.
  const auto ep_1_sym = Symbols().Register("ep_1");
  const auto ep_2_sym = Symbols().Register("ep_2");

  const auto& b_ancestors = sem_b->AncestorEntryPoints();
  ASSERT_EQ(2u, b_ancestors.size());
  EXPECT_EQ(ep_1_sym, b_ancestors[0]->Declaration()->symbol);
  EXPECT_EQ(ep_2_sym, b_ancestors[1]->Declaration()->symbol);

  const auto& a_ancestors = sem_a->AncestorEntryPoints();
  ASSERT_EQ(1u, a_ancestors.size());
  EXPECT_EQ(ep_1_sym, a_ancestors[0]->Declaration()->symbol);

  const auto& c_ancestors = sem_c->AncestorEntryPoints();
  ASSERT_EQ(2u, c_ancestors.size());
  EXPECT_EQ(ep_1_sym, c_ancestors[0]->Declaration()->symbol);
  EXPECT_EQ(ep_2_sym, c_ancestors[1]->Declaration()->symbol);

  EXPECT_TRUE(sem_ep_1->AncestorEntryPoints().empty());
  EXPECT_TRUE(sem_ep_2->AncestorEntryPoints().empty());
}
1915
1916 // Check for linear-time traversal of functions reachable from entry points.
1917 // See: crbug.com/tint/245
TEST_F(ResolverTest,Function_EntryPoints_LinearTime)1918 TEST_F(ResolverTest, Function_EntryPoints_LinearTime) {
1919 // fn lNa() { }
1920 // fn lNb() { }
1921 // ...
1922 // fn l2a() { l3a(); l3b(); }
1923 // fn l2b() { l3a(); l3b(); }
1924 // fn l1a() { l2a(); l2b(); }
1925 // fn l1b() { l2a(); l2b(); }
1926 // fn main() { l1a(); l1b(); }
1927
1928 static constexpr int levels = 64;
1929
1930 auto fn_a = [](int level) { return "l" + std::to_string(level + 1) + "a"; };
1931 auto fn_b = [](int level) { return "l" + std::to_string(level + 1) + "b"; };
1932
1933 Func(fn_a(levels), {}, ty.void_(), {}, {});
1934 Func(fn_b(levels), {}, ty.void_(), {}, {});
1935
1936 for (int i = levels - 1; i >= 0; i--) {
1937 Func(fn_a(i), {}, ty.void_(),
1938 {
1939 CallStmt(Call(fn_a(i + 1))),
1940 CallStmt(Call(fn_b(i + 1))),
1941 },
1942 {});
1943 Func(fn_b(i), {}, ty.void_(),
1944 {
1945 CallStmt(Call(fn_a(i + 1))),
1946 CallStmt(Call(fn_b(i + 1))),
1947 },
1948 {});
1949 }
1950
1951 Func("main", {}, ty.void_(),
1952 {
1953 CallStmt(Call(fn_a(0))),
1954 CallStmt(Call(fn_b(0))),
1955 },
1956 {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
1957
1958 ASSERT_TRUE(r()->Resolve()) << r()->error();
1959 }
1960
1961 // Test for crbug.com/tint/728
TEST_F(ResolverTest,ASTNodesAreReached)1962 TEST_F(ResolverTest, ASTNodesAreReached) {
1963 Structure("A", {Member("x", ty.array<f32, 4>(4))});
1964 Structure("B", {Member("x", ty.array<f32, 4>(4))});
1965 ASSERT_TRUE(r()->Resolve()) << r()->error();
1966 }
1967
// An expression that is created but never attached to any declaration is
// never visited by the resolver, which must raise an internal compiler error.
TEST_F(ResolverTest, ASTNodeNotReached) {
  EXPECT_FATAL_FAILURE(
      {
        ProgramBuilder builder;
        // Dangling identifier: not referenced by any function or global.
        builder.Expr("expr");
        Resolver(&builder).Resolve();
      },
      "internal compiler error: AST node 'tint::ast::IdentifierExpression' was "
      "not reached by the resolver");
}
1978
// Passing the same expression node to two globals places it in the AST twice,
// which the resolver must report as an internal compiler error.
TEST_F(ResolverTest, ASTNodeReachedTwice) {
  EXPECT_FATAL_FAILURE(
      {
        ProgramBuilder builder;
        auto* shared_init = builder.Expr(1);
        builder.Global("a", builder.ty.i32(), ast::StorageClass::kPrivate,
                       shared_init);
        builder.Global("b", builder.ty.i32(), ast::StorageClass::kPrivate,
                       shared_init);
        Resolver(&builder).Resolve();
      },
      "internal compiler error: AST node 'tint::ast::SintLiteralExpression' "
      "was encountered twice in the same AST of a Program");
}
1991
// Applying kNot (`!`) to a vec4<f32> operand must be rejected, with the
// diagnostic attributed to the operand's source location.
TEST_F(ResolverTest, UnaryOp_Not) {
  Global("ident", ty.vec4<f32>(), ast::StorageClass::kPrivate);
  auto* operand = Expr(Source{{12, 34}}, "ident");
  auto* not_expr = create<ast::UnaryOpExpression>(ast::UnaryOp::kNot, operand);
  WrapInFunction(not_expr);

  EXPECT_FALSE(r()->Resolve());
  // NOTE(review): the expected text has no closing quote after the type name;
  // presumably this mirrors the resolver's current message — confirm.
  EXPECT_EQ(r()->error(),
            "12:34 error: cannot logical negate expression of type 'vec4<f32>");
}
2002
// Applying kComplement (`~`) to a vec4<f32> operand must be rejected, with the
// diagnostic attributed to the operand's source location.
TEST_F(ResolverTest, UnaryOp_Complement) {
  Global("ident", ty.vec4<f32>(), ast::StorageClass::kPrivate);
  auto* operand = Expr(Source{{12, 34}}, "ident");
  auto* complement_expr =
      create<ast::UnaryOpExpression>(ast::UnaryOp::kComplement, operand);
  WrapInFunction(complement_expr);

  EXPECT_FALSE(r()->Resolve());
  // NOTE(review): the expected text has no closing quote after the type name;
  // presumably this mirrors the resolver's current message — confirm.
  EXPECT_EQ(
      r()->error(),
      "12:34 error: cannot bitwise complement expression of type 'vec4<f32>");
}
2014
// Applying kNegation (`-`) to a u32 operand must be rejected, with the
// diagnostic attributed to the operand's source location.
TEST_F(ResolverTest, UnaryOp_Negation) {
  Global("ident", ty.u32(), ast::StorageClass::kPrivate);
  auto* operand = Expr(Source{{12, 34}}, "ident");
  auto* negate_expr =
      create<ast::UnaryOpExpression>(ast::UnaryOp::kNegation, operand);
  WrapInFunction(negate_expr);

  EXPECT_FALSE(r()->Resolve());
  // NOTE(review): the expected text has no closing quote after the type name;
  // presumably this mirrors the resolver's current message — confirm.
  EXPECT_EQ(r()->error(), "12:34 error: cannot negate expression of type 'u32");
}
2024 } // namespace
2025 } // namespace resolver
2026 } // namespace tint
2027