• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/transform/pad_array_elements.h"
16 
#include <functional>
#include <unordered_map>
#include <utility>
19 
20 #include "src/program_builder.h"
21 #include "src/sem/array.h"
22 #include "src/sem/call.h"
23 #include "src/sem/expression.h"
24 #include "src/sem/type_constructor.h"
25 #include "src/utils/map.h"
26 
27 TINT_INSTANTIATE_TYPEINFO(tint::transform::PadArrayElements);
28 
29 namespace tint {
30 namespace transform {
31 namespace {
32 
33 using ArrayBuilder = std::function<const ast::Array*()>;
34 
35 /// PadArray returns a function that constructs a new array in `ctx.dst` with
36 /// the element type padded to account for the explicit stride. PadArray will
37 /// recursively pad arrays-of-arrays. The new array element type will be added
38 /// to module-scope type declarations of `ctx.dst`.
39 /// @param ctx the CloneContext
40 /// @param create_ast_type_for Transform::CreateASTTypeFor()
41 /// @param padded_arrays a map of src array type to the new array name
42 /// @param array the array type
43 /// @return the new AST array
44 template <typename CREATE_AST_TYPE_FOR>
PadArray(CloneContext & ctx,CREATE_AST_TYPE_FOR && create_ast_type_for,std::unordered_map<const sem::Array *,ArrayBuilder> & padded_arrays,const sem::Array * array)45 ArrayBuilder PadArray(
46     CloneContext& ctx,
47     CREATE_AST_TYPE_FOR&& create_ast_type_for,
48     std::unordered_map<const sem::Array*, ArrayBuilder>& padded_arrays,
49     const sem::Array* array) {
50   if (array->IsStrideImplicit()) {
51     // We don't want to wrap arrays that have an implicit stride
52     return nullptr;
53   }
54 
55   return utils::GetOrCreate(padded_arrays, array, [&] {
56     // Generate a unique name for the array element type
57     auto name = ctx.dst->Symbols().New("tint_padded_array_element");
58 
59     // Examine the element type. Is it also an array?
60     const ast::Type* el_ty = nullptr;
61     if (auto* el_array = array->ElemType()->As<sem::Array>()) {
62       // Array of array - call PadArray() on the element type
63       if (auto p =
64               PadArray(ctx, create_ast_type_for, padded_arrays, el_array)) {
65         el_ty = p();
66       }
67     }
68 
69     // If the element wasn't a padded array, just create the typical AST type
70     // for it
71     if (el_ty == nullptr) {
72       el_ty = create_ast_type_for(ctx, array->ElemType());
73     }
74 
75     // Structure() will create and append the ast::Struct to the
76     // global declarations of `ctx.dst`. As we haven't finished building the
77     // current module-scope statement or function, this will be placed
78     // immediately before the usage.
79     ctx.dst->Structure(
80         name,
81         {ctx.dst->Member("el", el_ty, {ctx.dst->MemberSize(array->Stride())})});
82 
83     auto* dst = ctx.dst;
84     return [=] {
85       if (array->IsRuntimeSized()) {
86         return dst->ty.array(dst->create<ast::TypeName>(name));
87       } else {
88         return dst->ty.array(dst->create<ast::TypeName>(name), array->Count());
89       }
90     };
91   });
92 }
93 
94 }  // namespace
95 
96 PadArrayElements::PadArrayElements() = default;
97 
98 PadArrayElements::~PadArrayElements() = default;
99 
Run(CloneContext & ctx,const DataMap &,DataMap &)100 void PadArrayElements::Run(CloneContext& ctx, const DataMap&, DataMap&) {
101   auto& sem = ctx.src->Sem();
102 
103   std::unordered_map<const sem::Array*, ArrayBuilder> padded_arrays;
104   auto pad = [&](const sem::Array* array) {
105     return PadArray(ctx, CreateASTTypeFor, padded_arrays, array);
106   };
107 
108   // Replace all array types with their corresponding padded array type
109   ctx.ReplaceAll([&](const ast::Type* ast_type) -> const ast::Type* {
110     auto* type = ctx.src->TypeOf(ast_type);
111     if (auto* array = type->UnwrapRef()->As<sem::Array>()) {
112       if (auto p = pad(array)) {
113         return p();
114       }
115     }
116     return nullptr;
117   });
118 
119   // Fix up index accessors so `a[1]` becomes `a[1].el`
120   ctx.ReplaceAll([&](const ast::IndexAccessorExpression* accessor)
121                      -> const ast::Expression* {
122     if (auto* array = tint::As<sem::Array>(
123             sem.Get(accessor->object)->Type()->UnwrapRef())) {
124       if (pad(array)) {
125         // Array element is wrapped in a structure. Emit a member accessor
126         // to get to the actual array element.
127         auto* idx = ctx.CloneWithoutTransform(accessor);
128         return ctx.dst->MemberAccessor(idx, "el");
129       }
130     }
131     return nullptr;
132   });
133 
134   // Fix up array constructors so `A(1,2)` becomes
135   // `A(padded(1), padded(2))`
136   ctx.ReplaceAll(
137       [&](const ast::CallExpression* expr) -> const ast::Expression* {
138         auto* call = sem.Get(expr);
139         if (auto* ctor = call->Target()->As<sem::TypeConstructor>()) {
140           if (auto* array = ctor->ReturnType()->As<sem::Array>()) {
141             if (auto p = pad(array)) {
142               auto* arr_ty = p();
143               auto el_typename = arr_ty->type->As<ast::TypeName>()->name;
144 
145               ast::ExpressionList args;
146               args.reserve(call->Arguments().size());
147               for (auto* arg : call->Arguments()) {
148                 auto* val = ctx.Clone(arg->Declaration());
149                 args.emplace_back(ctx.dst->Construct(
150                     ctx.dst->create<ast::TypeName>(el_typename), val));
151               }
152 
153               return ctx.dst->Construct(arr_ty, args);
154             }
155           }
156         }
157         return nullptr;
158       });
159 
160   ctx.Clone();
161 }
162 
163 }  // namespace transform
164 }  // namespace tint
165