1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/transform/vectorize_scalar_matrix_constructors.h"
16
17 #include <string>
18 #include <utility>
19
20 #include "src/transform/test_helper.h"
21 #include "src/utils/string.h"
22
23 namespace tint {
24 namespace transform {
25 namespace {
26
27 using VectorizeScalarMatrixConstructorsTest =
28 TransformTestWithParam<std::pair<uint32_t, uint32_t>>;
29
// A matrix constructed from a flat list of scalars is expected to be rewritten
// into a constructor that takes one vector per column, holding the same
// scalars in the same (column-major) order.
TEST_P(VectorizeScalarMatrixConstructorsTest, Basic) {
  const uint32_t num_cols = GetParam().first;
  const uint32_t num_rows = GetParam().second;
  const std::string matrix_type = "mat" + std::to_string(num_cols) + "x" +
                                  std::to_string(num_rows) + "<f32>";
  const std::string column_type = "vec" + std::to_string(num_rows) + "<f32>";

  // scalar_args: "0.0, 1.0, 2.0, 3.0"                       (input form)
  // column_args: "vec2<f32>(0.0, 1.0), vec2<f32>(2.0, 3.0)" (expected form)
  std::string scalar_args;
  std::string column_args;
  for (uint32_t col = 0; col < num_cols; ++col) {
    // Comma-separated scalars belonging to this column.
    std::string column_elements;
    for (uint32_t row = 0; row < num_rows; ++row) {
      if (row != 0) {
        column_elements += ", ";
      }
      column_elements += std::to_string(col * num_rows + row) + ".0";
    }

    const std::string separator = (col == 0) ? "" : ", ";
    scalar_args += separator + column_elements;
    column_args += separator + column_type + "(" + column_elements + ")";
  }

  std::string shader = R"(
[[stage(fragment)]]
fn main() {
  let m = ${matrix}(${values});
}
)";
  shader = utils::ReplaceAll(shader, "${matrix}", matrix_type);
  const std::string src = utils::ReplaceAll(shader, "${values}", scalar_args);
  const std::string expect =
      utils::ReplaceAll(shader, "${values}", column_args);

  auto got = Run<VectorizeScalarMatrixConstructors>(src);

  EXPECT_EQ(expect, str(got));
}
70
// A matrix constructed from vector columns (here, argument-less vector
// constructors) is expected to pass through the transform unchanged.
TEST_P(VectorizeScalarMatrixConstructorsTest, NonScalarConstructors) {
  const uint32_t num_cols = GetParam().first;
  const uint32_t num_rows = GetParam().second;
  const std::string matrix_type = "mat" + std::to_string(num_cols) + "x" +
                                  std::to_string(num_rows) + "<f32>";
  const std::string column_ctor = "vec" + std::to_string(num_rows) + "<f32>()";

  // e.g. "vec2<f32>(), vec2<f32>()" for a 2x2 matrix.
  std::string column_args;
  for (uint32_t col = 0; col < num_cols; ++col) {
    if (col != 0) {
      column_args += ", ";
    }
    column_args += column_ctor;
  }

  std::string shader = R"(
[[stage(fragment)]]
fn main() {
  let m = ${matrix}(${columns});
}
)";
  shader = utils::ReplaceAll(shader, "${matrix}", matrix_type);
  const std::string src = utils::ReplaceAll(shader, "${columns}", column_args);
  // No rewrite is expected, so the output must equal the input.
  const std::string expect = src;

  auto got = Run<VectorizeScalarMatrixConstructors>(src);

  EXPECT_EQ(expect, str(got));
}
99
100 INSTANTIATE_TEST_SUITE_P(VectorizeScalarMatrixConstructorsTest,
101 VectorizeScalarMatrixConstructorsTest,
102 testing::Values(std::make_pair(2, 2),
103 std::make_pair(2, 3),
104 std::make_pair(2, 4),
105 std::make_pair(3, 2),
106 std::make_pair(3, 3),
107 std::make_pair(3, 4),
108 std::make_pair(4, 2),
109 std::make_pair(4, 3),
110 std::make_pair(4, 4)));
111
112 } // namespace
113 } // namespace transform
114 } // namespace tint
115