1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "gmock/gmock.h"
16 #include "src/ast/call_statement.h"
17 #include "src/ast/stage_decoration.h"
18 #include "src/sem/call.h"
19 #include "src/writer/hlsl/test_helper.h"
20
21 namespace tint {
22 namespace writer {
23 namespace hlsl {
24 namespace {
25
// Shorthand for the semantic intrinsic enum used throughout these tests.
using IntrinsicType = sem::IntrinsicType;

// NOTE(review): HasSubstr is not referenced anywhere in this file; it may be
// removable.
using ::testing::HasSubstr;

// Fixture for the non-parameterized intrinsic tests in this file.
using HlslGeneratorImplTest_Intrinsic = TestHelper;
31
// Which argument flavour a HlslIntrinsicTest row exercises: f32, u32 or bool
// operands (GenerateCall() maps these to the f2 / u2 / b2 globals).
enum class ParamType { kF32, kU32, kBool };
37
38 struct IntrinsicData {
39 IntrinsicType intrinsic;
40 ParamType type;
41 const char* hlsl_name;
42 };
operator <<(std::ostream & out,IntrinsicData data)43 inline std::ostream& operator<<(std::ostream& out, IntrinsicData data) {
44 out << data.hlsl_name;
45 switch (data.type) {
46 case ParamType::kF32:
47 out << "f32";
48 break;
49 case ParamType::kU32:
50 out << "u32";
51 break;
52 case ParamType::kBool:
53 out << "bool";
54 break;
55 }
56 out << ">";
57 return out;
58 }
59
GenerateCall(IntrinsicType intrinsic,ParamType type,ProgramBuilder * builder)60 const ast::CallExpression* GenerateCall(IntrinsicType intrinsic,
61 ParamType type,
62 ProgramBuilder* builder) {
63 std::string name;
64 std::ostringstream str(name);
65 str << intrinsic;
66 switch (intrinsic) {
67 case IntrinsicType::kAcos:
68 case IntrinsicType::kAsin:
69 case IntrinsicType::kAtan:
70 case IntrinsicType::kCeil:
71 case IntrinsicType::kCos:
72 case IntrinsicType::kCosh:
73 case IntrinsicType::kDpdx:
74 case IntrinsicType::kDpdxCoarse:
75 case IntrinsicType::kDpdxFine:
76 case IntrinsicType::kDpdy:
77 case IntrinsicType::kDpdyCoarse:
78 case IntrinsicType::kDpdyFine:
79 case IntrinsicType::kExp:
80 case IntrinsicType::kExp2:
81 case IntrinsicType::kFloor:
82 case IntrinsicType::kFract:
83 case IntrinsicType::kFwidth:
84 case IntrinsicType::kFwidthCoarse:
85 case IntrinsicType::kFwidthFine:
86 case IntrinsicType::kInverseSqrt:
87 case IntrinsicType::kIsFinite:
88 case IntrinsicType::kIsInf:
89 case IntrinsicType::kIsNan:
90 case IntrinsicType::kIsNormal:
91 case IntrinsicType::kLength:
92 case IntrinsicType::kLog:
93 case IntrinsicType::kLog2:
94 case IntrinsicType::kNormalize:
95 case IntrinsicType::kRound:
96 case IntrinsicType::kSin:
97 case IntrinsicType::kSinh:
98 case IntrinsicType::kSqrt:
99 case IntrinsicType::kTan:
100 case IntrinsicType::kTanh:
101 case IntrinsicType::kTrunc:
102 case IntrinsicType::kSign:
103 return builder->Call(str.str(), "f2");
104 case IntrinsicType::kLdexp:
105 return builder->Call(str.str(), "f2", "i2");
106 case IntrinsicType::kAtan2:
107 case IntrinsicType::kDot:
108 case IntrinsicType::kDistance:
109 case IntrinsicType::kPow:
110 case IntrinsicType::kReflect:
111 case IntrinsicType::kStep:
112 return builder->Call(str.str(), "f2", "f2");
113 case IntrinsicType::kCross:
114 return builder->Call(str.str(), "f3", "f3");
115 case IntrinsicType::kFma:
116 case IntrinsicType::kMix:
117 case IntrinsicType::kFaceForward:
118 case IntrinsicType::kSmoothStep:
119 return builder->Call(str.str(), "f2", "f2", "f2");
120 case IntrinsicType::kAll:
121 case IntrinsicType::kAny:
122 return builder->Call(str.str(), "b2");
123 case IntrinsicType::kAbs:
124 if (type == ParamType::kF32) {
125 return builder->Call(str.str(), "f2");
126 } else {
127 return builder->Call(str.str(), "u2");
128 }
129 case IntrinsicType::kCountOneBits:
130 case IntrinsicType::kReverseBits:
131 return builder->Call(str.str(), "u2");
132 case IntrinsicType::kMax:
133 case IntrinsicType::kMin:
134 if (type == ParamType::kF32) {
135 return builder->Call(str.str(), "f2", "f2");
136 } else {
137 return builder->Call(str.str(), "u2", "u2");
138 }
139 case IntrinsicType::kClamp:
140 if (type == ParamType::kF32) {
141 return builder->Call(str.str(), "f2", "f2", "f2");
142 } else {
143 return builder->Call(str.str(), "u2", "u2", "u2");
144 }
145 case IntrinsicType::kSelect:
146 return builder->Call(str.str(), "f2", "f2", "b2");
147 case IntrinsicType::kDeterminant:
148 return builder->Call(str.str(), "m2x2");
149 case IntrinsicType::kTranspose:
150 return builder->Call(str.str(), "m3x2");
151 default:
152 break;
153 }
154 return nullptr;
155 }
156 using HlslIntrinsicTest = TestParamHelper<IntrinsicData>;
TEST_P(HlslIntrinsicTest,Emit)157 TEST_P(HlslIntrinsicTest, Emit) {
158 auto param = GetParam();
159
160 Global("f2", ty.vec2<f32>(), ast::StorageClass::kPrivate);
161 Global("f3", ty.vec3<f32>(), ast::StorageClass::kPrivate);
162 Global("u2", ty.vec2<u32>(), ast::StorageClass::kPrivate);
163 Global("i2", ty.vec2<i32>(), ast::StorageClass::kPrivate);
164 Global("b2", ty.vec2<bool>(), ast::StorageClass::kPrivate);
165 Global("m2x2", ty.mat2x2<f32>(), ast::StorageClass::kPrivate);
166 Global("m3x2", ty.mat3x2<f32>(), ast::StorageClass::kPrivate);
167
168 auto* call = GenerateCall(param.intrinsic, param.type, this);
169 ASSERT_NE(nullptr, call) << "Unhandled intrinsic";
170 Func("func", {}, ty.void_(), {CallStmt(call)},
171 {create<ast::StageDecoration>(ast::PipelineStage::kFragment)});
172
173 GeneratorImpl& gen = Build();
174
175 auto* sem = program->Sem().Get(call);
176 ASSERT_NE(sem, nullptr);
177 auto* target = sem->Target();
178 ASSERT_NE(target, nullptr);
179 auto* intrinsic = target->As<sem::Intrinsic>();
180 ASSERT_NE(intrinsic, nullptr);
181
182 EXPECT_EQ(gen.generate_builtin_name(intrinsic), param.hlsl_name);
183 }
184 INSTANTIATE_TEST_SUITE_P(
185 HlslGeneratorImplTest_Intrinsic,
186 HlslIntrinsicTest,
187 testing::Values(
188 IntrinsicData{IntrinsicType::kAbs, ParamType::kF32, "abs"},
189 IntrinsicData{IntrinsicType::kAbs, ParamType::kU32, "abs"},
190 IntrinsicData{IntrinsicType::kAcos, ParamType::kF32, "acos"},
191 IntrinsicData{IntrinsicType::kAll, ParamType::kBool, "all"},
192 IntrinsicData{IntrinsicType::kAny, ParamType::kBool, "any"},
193 IntrinsicData{IntrinsicType::kAsin, ParamType::kF32, "asin"},
194 IntrinsicData{IntrinsicType::kAtan, ParamType::kF32, "atan"},
195 IntrinsicData{IntrinsicType::kAtan2, ParamType::kF32, "atan2"},
196 IntrinsicData{IntrinsicType::kCeil, ParamType::kF32, "ceil"},
197 IntrinsicData{IntrinsicType::kClamp, ParamType::kF32, "clamp"},
198 IntrinsicData{IntrinsicType::kClamp, ParamType::kU32, "clamp"},
199 IntrinsicData{IntrinsicType::kCos, ParamType::kF32, "cos"},
200 IntrinsicData{IntrinsicType::kCosh, ParamType::kF32, "cosh"},
201 IntrinsicData{IntrinsicType::kCountOneBits, ParamType::kU32,
202 "countbits"},
203 IntrinsicData{IntrinsicType::kCross, ParamType::kF32, "cross"},
204 IntrinsicData{IntrinsicType::kDeterminant, ParamType::kF32,
205 "determinant"},
206 IntrinsicData{IntrinsicType::kDistance, ParamType::kF32, "distance"},
207 IntrinsicData{IntrinsicType::kDot, ParamType::kF32, "dot"},
208 IntrinsicData{IntrinsicType::kDpdx, ParamType::kF32, "ddx"},
209 IntrinsicData{IntrinsicType::kDpdxCoarse, ParamType::kF32,
210 "ddx_coarse"},
211 IntrinsicData{IntrinsicType::kDpdxFine, ParamType::kF32, "ddx_fine"},
212 IntrinsicData{IntrinsicType::kDpdy, ParamType::kF32, "ddy"},
213 IntrinsicData{IntrinsicType::kDpdyCoarse, ParamType::kF32,
214 "ddy_coarse"},
215 IntrinsicData{IntrinsicType::kDpdyFine, ParamType::kF32, "ddy_fine"},
216 IntrinsicData{IntrinsicType::kExp, ParamType::kF32, "exp"},
217 IntrinsicData{IntrinsicType::kExp2, ParamType::kF32, "exp2"},
218 IntrinsicData{IntrinsicType::kFaceForward, ParamType::kF32,
219 "faceforward"},
220 IntrinsicData{IntrinsicType::kFloor, ParamType::kF32, "floor"},
221 IntrinsicData{IntrinsicType::kFma, ParamType::kF32, "mad"},
222 IntrinsicData{IntrinsicType::kFract, ParamType::kF32, "frac"},
223 IntrinsicData{IntrinsicType::kFwidth, ParamType::kF32, "fwidth"},
224 IntrinsicData{IntrinsicType::kFwidthCoarse, ParamType::kF32, "fwidth"},
225 IntrinsicData{IntrinsicType::kFwidthFine, ParamType::kF32, "fwidth"},
226 IntrinsicData{IntrinsicType::kInverseSqrt, ParamType::kF32, "rsqrt"},
227 IntrinsicData{IntrinsicType::kIsFinite, ParamType::kF32, "isfinite"},
228 IntrinsicData{IntrinsicType::kIsInf, ParamType::kF32, "isinf"},
229 IntrinsicData{IntrinsicType::kIsNan, ParamType::kF32, "isnan"},
230 IntrinsicData{IntrinsicType::kLdexp, ParamType::kF32, "ldexp"},
231 IntrinsicData{IntrinsicType::kLength, ParamType::kF32, "length"},
232 IntrinsicData{IntrinsicType::kLog, ParamType::kF32, "log"},
233 IntrinsicData{IntrinsicType::kLog2, ParamType::kF32, "log2"},
234 IntrinsicData{IntrinsicType::kMax, ParamType::kF32, "max"},
235 IntrinsicData{IntrinsicType::kMax, ParamType::kU32, "max"},
236 IntrinsicData{IntrinsicType::kMin, ParamType::kF32, "min"},
237 IntrinsicData{IntrinsicType::kMin, ParamType::kU32, "min"},
238 IntrinsicData{IntrinsicType::kMix, ParamType::kF32, "lerp"},
239 IntrinsicData{IntrinsicType::kNormalize, ParamType::kF32, "normalize"},
240 IntrinsicData{IntrinsicType::kPow, ParamType::kF32, "pow"},
241 IntrinsicData{IntrinsicType::kReflect, ParamType::kF32, "reflect"},
242 IntrinsicData{IntrinsicType::kReverseBits, ParamType::kU32,
243 "reversebits"},
244 IntrinsicData{IntrinsicType::kRound, ParamType::kU32, "round"},
245 IntrinsicData{IntrinsicType::kSign, ParamType::kF32, "sign"},
246 IntrinsicData{IntrinsicType::kSin, ParamType::kF32, "sin"},
247 IntrinsicData{IntrinsicType::kSinh, ParamType::kF32, "sinh"},
248 IntrinsicData{IntrinsicType::kSmoothStep, ParamType::kF32,
249 "smoothstep"},
250 IntrinsicData{IntrinsicType::kSqrt, ParamType::kF32, "sqrt"},
251 IntrinsicData{IntrinsicType::kStep, ParamType::kF32, "step"},
252 IntrinsicData{IntrinsicType::kTan, ParamType::kF32, "tan"},
253 IntrinsicData{IntrinsicType::kTanh, ParamType::kF32, "tanh"},
254 IntrinsicData{IntrinsicType::kTranspose, ParamType::kF32, "transpose"},
255 IntrinsicData{IntrinsicType::kTrunc, ParamType::kF32, "trunc"}));
256
// Disabled placeholder. isNormal() is exercised by the IsNormal_Scalar and
// IsNormal_Vector tests in this file.
// NOTE(review): this stub looks redundant and may be deletable.
TEST_F(HlslGeneratorImplTest_Intrinsic, DISABLED_Intrinsic_IsNormal) {
  GTEST_FAIL();
}
260
// A plain intrinsic call is emitted verbatim with its identifier arguments.
TEST_F(HlslGeneratorImplTest_Intrinsic, Intrinsic_Call) {
  auto* dot_call = Call("dot", "param1", "param2");

  // Both operands are private module-scope vec3<f32> variables.
  Global("param1", ty.vec3<f32>(), ast::StorageClass::kPrivate);
  Global("param2", ty.vec3<f32>(), ast::StorageClass::kPrivate);

  WrapInFunction(CallStmt(dot_call));

  GeneratorImpl& gen = Build();
  gen.increment_indent();

  std::stringstream hlsl;
  ASSERT_TRUE(gen.EmitExpression(hlsl, dot_call)) << gen.error();
  EXPECT_EQ(hlsl.str(), "dot(param1, param2)");
}
276
// select(f, t, cond) is expected to lower to the ternary (cond ? t : f).
TEST_F(HlslGeneratorImplTest_Intrinsic, Select_Scalar) {
  auto* sel = Call("select", 1.0f, 2.0f, true);
  WrapInFunction(CallStmt(sel));
  GeneratorImpl& gen = Build();
  gen.increment_indent();

  std::stringstream hlsl;
  ASSERT_TRUE(gen.EmitExpression(hlsl, sel)) << gen.error();
  EXPECT_EQ(hlsl.str(), "(true ? 2.0f : 1.0f)");
}
287
// The vector form of select() uses the same ternary shape, with a bool2
// condition choosing per component.
TEST_F(HlslGeneratorImplTest_Intrinsic, Select_Vector) {
  auto* if_false = vec2<i32>(1, 2);
  auto* if_true = vec2<i32>(3, 4);
  auto* cond = vec2<bool>(true, false);
  auto* sel = Call("select", if_false, if_true, cond);
  WrapInFunction(CallStmt(sel));
  GeneratorImpl& gen = Build();
  gen.increment_indent();

  std::stringstream hlsl;
  ASSERT_TRUE(gen.EmitExpression(hlsl, sel)) << gen.error();
  EXPECT_EQ(hlsl.str(), "(bool2(true, false) ? int2(3, 4) : int2(1, 2))");
}
299
// Scalar modf() goes through a tint_modf helper that returns a
// {fract, whole} struct.
TEST_F(HlslGeneratorImplTest_Intrinsic, Modf_Scalar) {
  WrapInFunction(CallStmt(Call("modf", 1.0f)));

  GeneratorImpl& gen = SanitizeAndBuild();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(struct modf_result {
  float fract;
  float whole;
};
modf_result tint_modf(float param_0) {
  float whole;
  float fract = modf(param_0, whole);
  modf_result result = {fract, whole};
  return result;
}

[numthreads(1, 1, 1)]
void test_function() {
  tint_modf(1.0f);
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
325
// Vector modf() uses a vector-specific result struct (modf_result_vec3).
TEST_F(HlslGeneratorImplTest_Intrinsic, Modf_Vector) {
  WrapInFunction(CallStmt(Call("modf", vec3<f32>())));

  GeneratorImpl& gen = SanitizeAndBuild();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(struct modf_result_vec3 {
  float3 fract;
  float3 whole;
};
modf_result_vec3 tint_modf(float3 param_0) {
  float3 whole;
  float3 fract = modf(param_0, whole);
  modf_result_vec3 result = {fract, whole};
  return result;
}

[numthreads(1, 1, 1)]
void test_function() {
  tint_modf(float3(0.0f, 0.0f, 0.0f));
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
351
// Scalar frexp() goes through a tint_frexp helper. The exponent comes back as
// a float and is converted with int() when the result struct is built.
TEST_F(HlslGeneratorImplTest_Intrinsic, Frexp_Scalar_i32) {
  WrapInFunction(CallStmt(Call("frexp", 1.0f)));

  GeneratorImpl& gen = SanitizeAndBuild();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(struct frexp_result {
  float sig;
  int exp;
};
frexp_result tint_frexp(float param_0) {
  float exp;
  float sig = frexp(param_0, exp);
  frexp_result result = {sig, int(exp)};
  return result;
}

[numthreads(1, 1, 1)]
void test_function() {
  tint_frexp(1.0f);
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
377
// Vector frexp() converts the float3 exponent to int3 inside the helper.
TEST_F(HlslGeneratorImplTest_Intrinsic, Frexp_Vector_i32) {
  WrapInFunction(CallStmt(Call("frexp", vec3<f32>())));

  GeneratorImpl& gen = SanitizeAndBuild();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(struct frexp_result_vec3 {
  float3 sig;
  int3 exp;
};
frexp_result_vec3 tint_frexp(float3 param_0) {
  float3 exp;
  float3 sig = frexp(param_0, exp);
  frexp_result_vec3 result = {sig, int3(exp)};
  return result;
}

[numthreads(1, 1, 1)]
void test_function() {
  tint_frexp(float3(0.0f, 0.0f, 0.0f));
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
403
// isNormal() on an f32 is emitted as a bit-twiddling tint_isNormal helper.
//
// NOTE(review): the masks in this golden (0x7f80000, 0x0080000, 0x7f00000)
// are one hex digit short of the IEEE-754 binary32 constants: exponent mask
// 0x7f800000, smallest normal exponent 0x00800000, largest finite exponent
// 0x7f000000. The golden mirrors the generator's current output, so fix the
// generator and this expectation together.
TEST_F(HlslGeneratorImplTest_Intrinsic, IsNormal_Scalar) {
  auto* operand = Var("val", ty.f32());
  auto* is_normal = Call("isNormal", operand);
  WrapInFunction(operand, is_normal);

  GeneratorImpl& gen = SanitizeAndBuild();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(bool tint_isNormal(float param_0) {
  uint exponent = asuint(param_0) & 0x7f80000;
  uint clamped = clamp(exponent, 0x0080000, 0x7f00000);
  return clamped == exponent;
}

[numthreads(1, 1, 1)]
void test_function() {
  float val = 0.0f;
  const bool tint_symbol = tint_isNormal(val);
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
426
// Vector isNormal() uses the same helper shape as the scalar case, lifted to
// uint3/bool3.
//
// NOTE(review): the masks in this golden (0x7f80000, 0x0080000, 0x7f00000)
// are one hex digit short of the IEEE-754 binary32 constants (0x7f800000,
// 0x00800000, 0x7f000000). The golden mirrors the generator's current output,
// so fix the generator and this expectation together.
TEST_F(HlslGeneratorImplTest_Intrinsic, IsNormal_Vector) {
  auto* operand = Var("val", ty.vec3<f32>());
  auto* is_normal = Call("isNormal", operand);
  WrapInFunction(operand, is_normal);

  GeneratorImpl& gen = SanitizeAndBuild();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(bool3 tint_isNormal(float3 param_0) {
  uint3 exponent = asuint(param_0) & 0x7f80000;
  uint3 clamped = clamp(exponent, 0x0080000, 0x7f00000);
  return clamped == exponent;
}

[numthreads(1, 1, 1)]
void test_function() {
  float3 val = float3(0.0f, 0.0f, 0.0f);
  const bool3 tint_symbol = tint_isNormal(val);
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
449
// pack4x8snorm: clamp to [-1, 1], scale by 127, round, mask each lane to a
// byte and OR the bytes together.
TEST_F(HlslGeneratorImplTest_Intrinsic, Pack4x8Snorm) {
  auto* pack = Call("pack4x8snorm", "p1");
  Global("p1", ty.vec4<f32>(), ast::StorageClass::kPrivate);
  WrapInFunction(CallStmt(pack));
  GeneratorImpl& gen = Build();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(uint tint_pack4x8snorm(float4 param_0) {
  int4 i = int4(round(clamp(param_0, -1.0, 1.0) * 127.0)) & 0xff;
  return asuint(i.x | i.y << 8 | i.z << 16 | i.w << 24);
}

static float4 p1 = float4(0.0f, 0.0f, 0.0f, 0.0f);

[numthreads(1, 1, 1)]
void test_function() {
  tint_pack4x8snorm(p1);
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
471
// pack4x8unorm: clamp to [0, 1], scale by 255, round, then OR the four bytes
// into one uint.
TEST_F(HlslGeneratorImplTest_Intrinsic, Pack4x8Unorm) {
  auto* pack = Call("pack4x8unorm", "p1");
  Global("p1", ty.vec4<f32>(), ast::StorageClass::kPrivate);
  WrapInFunction(CallStmt(pack));
  GeneratorImpl& gen = Build();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(uint tint_pack4x8unorm(float4 param_0) {
  uint4 i = uint4(round(clamp(param_0, 0.0, 1.0) * 255.0));
  return (i.x | i.y << 8 | i.z << 16 | i.w << 24);
}

static float4 p1 = float4(0.0f, 0.0f, 0.0f, 0.0f);

[numthreads(1, 1, 1)]
void test_function() {
  tint_pack4x8unorm(p1);
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
493
// pack2x16snorm: clamp to [-1, 1], scale by 32767, round, mask each lane to 16
// bits and combine.
TEST_F(HlslGeneratorImplTest_Intrinsic, Pack2x16Snorm) {
  auto* pack = Call("pack2x16snorm", "p1");
  Global("p1", ty.vec2<f32>(), ast::StorageClass::kPrivate);
  WrapInFunction(CallStmt(pack));
  GeneratorImpl& gen = Build();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(uint tint_pack2x16snorm(float2 param_0) {
  int2 i = int2(round(clamp(param_0, -1.0, 1.0) * 32767.0)) & 0xffff;
  return asuint(i.x | i.y << 16);
}

static float2 p1 = float2(0.0f, 0.0f);

[numthreads(1, 1, 1)]
void test_function() {
  tint_pack2x16snorm(p1);
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
515
// pack2x16unorm: clamp to [0, 1], scale by 65535, round, then combine the two
// halves.
TEST_F(HlslGeneratorImplTest_Intrinsic, Pack2x16Unorm) {
  auto* pack = Call("pack2x16unorm", "p1");
  Global("p1", ty.vec2<f32>(), ast::StorageClass::kPrivate);
  WrapInFunction(CallStmt(pack));
  GeneratorImpl& gen = Build();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(uint tint_pack2x16unorm(float2 param_0) {
  uint2 i = uint2(round(clamp(param_0, 0.0, 1.0) * 65535.0));
  return (i.x | i.y << 16);
}

static float2 p1 = float2(0.0f, 0.0f);

[numthreads(1, 1, 1)]
void test_function() {
  tint_pack2x16unorm(p1);
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
537
// pack2x16float converts with f32tof16 and combines the two halves.
TEST_F(HlslGeneratorImplTest_Intrinsic, Pack2x16Float) {
  auto* pack = Call("pack2x16float", "p1");
  Global("p1", ty.vec2<f32>(), ast::StorageClass::kPrivate);
  WrapInFunction(CallStmt(pack));
  GeneratorImpl& gen = Build();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(uint tint_pack2x16float(float2 param_0) {
  uint2 i = f32tof16(param_0);
  return i.x | (i.y << 16);
}

static float2 p1 = float2(0.0f, 0.0f);

[numthreads(1, 1, 1)]
void test_function() {
  tint_pack2x16float(p1);
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
559
// unpack4x8snorm shifts each byte to the top of an int and back down by 24
// (sign-extending), divides by 127 and clamps to [-1, 1].
TEST_F(HlslGeneratorImplTest_Intrinsic, Unpack4x8Snorm) {
  auto* unpack = Call("unpack4x8snorm", "p1");
  Global("p1", ty.u32(), ast::StorageClass::kPrivate);
  WrapInFunction(CallStmt(unpack));
  GeneratorImpl& gen = Build();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(float4 tint_unpack4x8snorm(uint param_0) {
  int j = int(param_0);
  int4 i = int4(j << 24, j << 16, j << 8, j) >> 24;
  return clamp(float4(i) / 127.0, -1.0, 1.0);
}

static uint p1 = 0u;

[numthreads(1, 1, 1)]
void test_function() {
  tint_unpack4x8snorm(p1);
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
582
// unpack4x8unorm extracts each byte with shifts and masks, then divides by
// 255.
TEST_F(HlslGeneratorImplTest_Intrinsic, Unpack4x8Unorm) {
  auto* unpack = Call("unpack4x8unorm", "p1");
  Global("p1", ty.u32(), ast::StorageClass::kPrivate);
  WrapInFunction(CallStmt(unpack));
  GeneratorImpl& gen = Build();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(float4 tint_unpack4x8unorm(uint param_0) {
  uint j = param_0;
  uint4 i = uint4(j & 0xff, (j >> 8) & 0xff, (j >> 16) & 0xff, j >> 24);
  return float4(i) / 255.0;
}

static uint p1 = 0u;

[numthreads(1, 1, 1)]
void test_function() {
  tint_unpack4x8unorm(p1);
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
605
// unpack2x16snorm sign-extends each 16-bit half via shifts, divides by 32767
// and clamps to [-1, 1].
TEST_F(HlslGeneratorImplTest_Intrinsic, Unpack2x16Snorm) {
  auto* unpack = Call("unpack2x16snorm", "p1");
  Global("p1", ty.u32(), ast::StorageClass::kPrivate);
  WrapInFunction(CallStmt(unpack));
  GeneratorImpl& gen = Build();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(float2 tint_unpack2x16snorm(uint param_0) {
  int j = int(param_0);
  int2 i = int2(j << 16, j) >> 16;
  return clamp(float2(i) / 32767.0, -1.0, 1.0);
}

static uint p1 = 0u;

[numthreads(1, 1, 1)]
void test_function() {
  tint_unpack2x16snorm(p1);
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
628
// unpack2x16unorm splits the uint into two 16-bit halves and divides by 65535.
TEST_F(HlslGeneratorImplTest_Intrinsic, Unpack2x16Unorm) {
  auto* unpack = Call("unpack2x16unorm", "p1");
  Global("p1", ty.u32(), ast::StorageClass::kPrivate);
  WrapInFunction(CallStmt(unpack));
  GeneratorImpl& gen = Build();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(float2 tint_unpack2x16unorm(uint param_0) {
  uint j = param_0;
  uint2 i = uint2(j & 0xffff, j >> 16);
  return float2(i) / 65535.0;
}

static uint p1 = 0u;

[numthreads(1, 1, 1)]
void test_function() {
  tint_unpack2x16unorm(p1);
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
651
// unpack2x16float splits the uint into two halves and converts with f16tof32.
TEST_F(HlslGeneratorImplTest_Intrinsic, Unpack2x16Float) {
  auto* unpack = Call("unpack2x16float", "p1");
  Global("p1", ty.u32(), ast::StorageClass::kPrivate);
  WrapInFunction(CallStmt(unpack));
  GeneratorImpl& gen = Build();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(float2 tint_unpack2x16float(uint param_0) {
  uint i = param_0;
  return f16tof32(uint2(i & 0xffff, i >> 16));
}

static uint p1 = 0u;

[numthreads(1, 1, 1)]
void test_function() {
  tint_unpack2x16float(p1);
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
673
// storageBarrier() maps to DeviceMemoryBarrierWithGroupSync().
TEST_F(HlslGeneratorImplTest_Intrinsic, StorageBarrier) {
  auto* barrier = CallStmt(Call("storageBarrier"));
  Func("main", {}, ty.void_(), {barrier},
       {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});

  GeneratorImpl& gen = Build();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"([numthreads(1, 1, 1)]
void main() {
  DeviceMemoryBarrierWithGroupSync();
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
691
// workgroupBarrier() maps to GroupMemoryBarrierWithGroupSync().
TEST_F(HlslGeneratorImplTest_Intrinsic, WorkgroupBarrier) {
  auto* barrier = CallStmt(Call("workgroupBarrier"));
  Func("main", {}, ty.void_(), {barrier},
       {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});

  GeneratorImpl& gen = Build();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"([numthreads(1, 1, 1)]
void main() {
  GroupMemoryBarrierWithGroupSync();
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
709
// ignore(expr) is expected to reduce to evaluating expr as a bare statement.
TEST_F(HlslGeneratorImplTest_Intrinsic, Ignore) {
  // f(a, b, c) returns (a + b) * c.
  Func("f", {Param("a", ty.i32()), Param("b", ty.i32()), Param("c", ty.i32())},
       ty.i32(), {Return(Mul(Add("a", "b"), "c"))});

  auto* ignored = Call("ignore", Call("f", 1, 2, 3));
  Func("main", {}, ty.void_(), {CallStmt(ignored)},
       {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});

  GeneratorImpl& gen = Build();
  ASSERT_TRUE(gen.Generate()) << gen.error();

  constexpr char kExpected[] = R"(int f(int a, int b, int c) {
  return ((a + b) * c);
}

[numthreads(1, 1, 1)]
void main() {
  f(1, 2, 3);
  return;
}
)";
  EXPECT_EQ(gen.result(), kExpected);
}
734
735 } // namespace
736 } // namespace hlsl
737 } // namespace writer
738 } // namespace tint
739