• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/transform/vectorize_scalar_matrix_constructors.h"
16 
17 #include <string>
18 #include <utility>
19 
20 #include "src/transform/test_helper.h"
21 #include "src/utils/string.h"
22 
23 namespace tint {
24 namespace transform {
25 namespace {
26 
27 using VectorizeScalarMatrixConstructorsTest =
28     TransformTestWithParam<std::pair<uint32_t, uint32_t>>;
29 
// Verifies that a matrix built from a flat list of scalar arguments is
// rewritten so that each column is passed as its own `vecN<f32>(...)`
// argument. Consecutive scalars are grouped `rows` at a time into columns.
TEST_P(VectorizeScalarMatrixConstructorsTest, Basic) {
  const uint32_t num_cols = GetParam().first;
  const uint32_t num_rows = GetParam().second;

  const std::string matrix_type = "mat" + std::to_string(num_cols) + "x" +
                                  std::to_string(num_rows) + "<f32>";
  const std::string column_type = "vec" + std::to_string(num_rows) + "<f32>";

  // `flat_args` is the scalar-only argument list ("0.0, 1.0, ...");
  // `column_args` is the same literals wrapped one column per vector.
  std::string flat_args;
  std::string column_args;
  uint32_t element = 0;  // Running literal value: col * num_rows + row.
  for (uint32_t col = 0; col < num_cols; ++col) {
    const std::string col_sep = col == 0 ? "" : ", ";
    flat_args += col_sep;
    column_args += col_sep + column_type + "(";
    for (uint32_t row = 0; row < num_rows; ++row, ++element) {
      const std::string elem_sep = row == 0 ? "" : ", ";
      const std::string literal = std::to_string(element) + ".0";
      flat_args += elem_sep + literal;
      column_args += elem_sep + literal;
    }
    column_args += ")";
  }

  std::string program = R"(
[[stage(fragment)]]
fn main() {
  let m = ${matrix}(${values});
}
)";
  program = utils::ReplaceAll(program, "${matrix}", matrix_type);
  auto src = utils::ReplaceAll(program, "${values}", flat_args);
  const auto expect = utils::ReplaceAll(program, "${values}", column_args);

  auto got = Run<VectorizeScalarMatrixConstructors>(src);

  EXPECT_EQ(expect, str(got));
}
70 
// Verifies that a matrix constructor whose arguments are already column
// vectors is left untouched: the transform's output must equal its input.
TEST_P(VectorizeScalarMatrixConstructorsTest, NonScalarConstructors) {
  const uint32_t num_cols = GetParam().first;
  const uint32_t num_rows = GetParam().second;

  const std::string matrix_type = "mat" + std::to_string(num_cols) + "x" +
                                  std::to_string(num_rows) + "<f32>";
  const std::string column_type = "vec" + std::to_string(num_rows) + "<f32>";

  // One zero-argument column vector per matrix column, comma separated.
  const std::string zero_column = column_type + "()";
  std::string column_args;
  for (uint32_t col = 0; col < num_cols; ++col) {
    column_args += col == 0 ? zero_column : ", " + zero_column;
  }

  std::string program = R"(
[[stage(fragment)]]
fn main() {
  let m = ${matrix}(${columns});
}
)";
  program = utils::ReplaceAll(program, "${matrix}", matrix_type);
  auto src = utils::ReplaceAll(program, "${columns}", column_args);
  // Already vectorized, so the expected output is the input verbatim.
  const auto expect = src;

  auto got = Run<VectorizeScalarMatrixConstructors>(src);

  EXPECT_EQ(expect, str(got));
}
99 
100 INSTANTIATE_TEST_SUITE_P(VectorizeScalarMatrixConstructorsTest,
101                          VectorizeScalarMatrixConstructorsTest,
102                          testing::Values(std::make_pair(2, 2),
103                                          std::make_pair(2, 3),
104                                          std::make_pair(2, 4),
105                                          std::make_pair(3, 2),
106                                          std::make_pair(3, 3),
107                                          std::make_pair(3, 4),
108                                          std::make_pair(4, 2),
109                                          std::make_pair(4, 3),
110                                          std::make_pair(4, 4)));
111 
112 }  // namespace
113 }  // namespace transform
114 }  // namespace tint
115