• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/transform/vertex_pulling.h"
16 
17 #include <algorithm>
18 #include <utility>
19 
20 #include "src/ast/assignment_statement.h"
21 #include "src/ast/bitcast_expression.h"
22 #include "src/ast/struct_block_decoration.h"
23 #include "src/ast/variable_decl_statement.h"
24 #include "src/program_builder.h"
25 #include "src/sem/variable.h"
26 #include "src/utils/map.h"
27 #include "src/utils/math.h"
28 
29 TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling);
30 TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling::Config);
31 
32 namespace tint {
33 namespace transform {
34 
35 namespace {
36 
/// The scalar component type of a vertex attribute.
/// A data type is either this type, or a vector of this type.
enum class BaseType {
  kInvalid,
  kU32,
  kI32,
  kF32,
};

/// Streams the textual name of a BaseType.
/// @param out the std::ostream to write to
/// @param format the BaseType to write
/// @returns out so calls can be chained
std::ostream& operator<<(std::ostream& out, BaseType format) {
  // Values that are not one of the declared enumerators keep this name.
  const char* name = "<unknown>";
  switch (format) {
    case BaseType::kInvalid:
      name = "invalid";
      break;
    case BaseType::kU32:
      name = "u32";
      break;
    case BaseType::kI32:
      name = "i32";
      break;
    case BaseType::kF32:
      name = "f32";
      break;
  }
  return out << name;
}
63 
64 /// Writes the VertexFormat to the std::ostream.
65 /// @param out the std::ostream to write to
66 /// @param format the VertexFormat to write
67 /// @returns out so calls can be chained
operator <<(std::ostream & out,VertexFormat format)68 std::ostream& operator<<(std::ostream& out, VertexFormat format) {
69   switch (format) {
70     case VertexFormat::kUint8x2:
71       return out << "uint8x2";
72     case VertexFormat::kUint8x4:
73       return out << "uint8x4";
74     case VertexFormat::kSint8x2:
75       return out << "sint8x2";
76     case VertexFormat::kSint8x4:
77       return out << "sint8x4";
78     case VertexFormat::kUnorm8x2:
79       return out << "unorm8x2";
80     case VertexFormat::kUnorm8x4:
81       return out << "unorm8x4";
82     case VertexFormat::kSnorm8x2:
83       return out << "snorm8x2";
84     case VertexFormat::kSnorm8x4:
85       return out << "snorm8x4";
86     case VertexFormat::kUint16x2:
87       return out << "uint16x2";
88     case VertexFormat::kUint16x4:
89       return out << "uint16x4";
90     case VertexFormat::kSint16x2:
91       return out << "sint16x2";
92     case VertexFormat::kSint16x4:
93       return out << "sint16x4";
94     case VertexFormat::kUnorm16x2:
95       return out << "unorm16x2";
96     case VertexFormat::kUnorm16x4:
97       return out << "unorm16x4";
98     case VertexFormat::kSnorm16x2:
99       return out << "snorm16x2";
100     case VertexFormat::kSnorm16x4:
101       return out << "snorm16x4";
102     case VertexFormat::kFloat16x2:
103       return out << "float16x2";
104     case VertexFormat::kFloat16x4:
105       return out << "float16x4";
106     case VertexFormat::kFloat32:
107       return out << "float32";
108     case VertexFormat::kFloat32x2:
109       return out << "float32x2";
110     case VertexFormat::kFloat32x3:
111       return out << "float32x3";
112     case VertexFormat::kFloat32x4:
113       return out << "float32x4";
114     case VertexFormat::kUint32:
115       return out << "uint32";
116     case VertexFormat::kUint32x2:
117       return out << "uint32x2";
118     case VertexFormat::kUint32x3:
119       return out << "uint32x3";
120     case VertexFormat::kUint32x4:
121       return out << "uint32x4";
122     case VertexFormat::kSint32:
123       return out << "sint32";
124     case VertexFormat::kSint32x2:
125       return out << "sint32x2";
126     case VertexFormat::kSint32x3:
127       return out << "sint32x3";
128     case VertexFormat::kSint32x4:
129       return out << "sint32x4";
130   }
131   return out << "<unknown>";
132 }
133 
134 /// A vertex attribute data format.
135 struct DataType {
136   BaseType base_type;
137   uint32_t width;  // 1 for scalar, 2+ for a vector
138 };
139 
DataTypeOf(const sem::Type * ty)140 DataType DataTypeOf(const sem::Type* ty) {
141   if (ty->Is<sem::I32>()) {
142     return {BaseType::kI32, 1};
143   }
144   if (ty->Is<sem::U32>()) {
145     return {BaseType::kU32, 1};
146   }
147   if (ty->Is<sem::F32>()) {
148     return {BaseType::kF32, 1};
149   }
150   if (auto* vec = ty->As<sem::Vector>()) {
151     return {DataTypeOf(vec->type()).base_type, vec->Width()};
152   }
153   return {BaseType::kInvalid, 0};
154 }
155 
DataTypeOf(VertexFormat format)156 DataType DataTypeOf(VertexFormat format) {
157   switch (format) {
158     case VertexFormat::kUint32:
159       return {BaseType::kU32, 1};
160     case VertexFormat::kUint8x2:
161     case VertexFormat::kUint16x2:
162     case VertexFormat::kUint32x2:
163       return {BaseType::kU32, 2};
164     case VertexFormat::kUint32x3:
165       return {BaseType::kU32, 3};
166     case VertexFormat::kUint8x4:
167     case VertexFormat::kUint16x4:
168     case VertexFormat::kUint32x4:
169       return {BaseType::kU32, 4};
170     case VertexFormat::kSint32:
171       return {BaseType::kI32, 1};
172     case VertexFormat::kSint8x2:
173     case VertexFormat::kSint16x2:
174     case VertexFormat::kSint32x2:
175       return {BaseType::kI32, 2};
176     case VertexFormat::kSint32x3:
177       return {BaseType::kI32, 3};
178     case VertexFormat::kSint8x4:
179     case VertexFormat::kSint16x4:
180     case VertexFormat::kSint32x4:
181       return {BaseType::kI32, 4};
182     case VertexFormat::kFloat32:
183       return {BaseType::kF32, 1};
184     case VertexFormat::kUnorm8x2:
185     case VertexFormat::kSnorm8x2:
186     case VertexFormat::kUnorm16x2:
187     case VertexFormat::kSnorm16x2:
188     case VertexFormat::kFloat16x2:
189     case VertexFormat::kFloat32x2:
190       return {BaseType::kF32, 2};
191     case VertexFormat::kFloat32x3:
192       return {BaseType::kF32, 3};
193     case VertexFormat::kUnorm8x4:
194     case VertexFormat::kSnorm8x4:
195     case VertexFormat::kUnorm16x4:
196     case VertexFormat::kSnorm16x4:
197     case VertexFormat::kFloat16x4:
198     case VertexFormat::kFloat32x4:
199       return {BaseType::kF32, 4};
200   }
201   return {BaseType::kInvalid, 0};
202 }
203 
204 struct State {
Statetint::transform::__anon572942530111::State205   State(CloneContext& context, const VertexPulling::Config& c)
206       : ctx(context), cfg(c) {}
207   State(const State&) = default;
208   ~State() = default;
209 
210   /// LocationReplacement describes an ast::Variable replacement for a
211   /// location input.
212   struct LocationReplacement {
213     /// The variable to replace in the source Program
214     ast::Variable* from;
215     /// The replacement to use in the target ProgramBuilder
216     ast::Variable* to;
217   };
218 
219   struct LocationInfo {
220     std::function<const ast::Expression*()> expr;
221     const sem::Type* type;
222   };
223 
224   CloneContext& ctx;
225   VertexPulling::Config const cfg;
226   std::unordered_map<uint32_t, LocationInfo> location_info;
227   std::function<const ast::Expression*()> vertex_index_expr = nullptr;
228   std::function<const ast::Expression*()> instance_index_expr = nullptr;
229   Symbol pulling_position_name;
230   Symbol struct_buffer_name;
231   std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
232   ast::VariableList new_function_parameters;
233 
234   /// Generate the vertex buffer binding name
235   /// @param index index to append to buffer name
GetVertexBufferNametint::transform::__anon572942530111::State236   Symbol GetVertexBufferName(uint32_t index) {
237     return utils::GetOrCreate(vertex_buffer_names, index, [&] {
238       static const char kVertexBufferNamePrefix[] =
239           "tint_pulling_vertex_buffer_";
240       return ctx.dst->Symbols().New(kVertexBufferNamePrefix +
241                                     std::to_string(index));
242     });
243   }
244 
245   /// Lazily generates the structure buffer symbol
GetStructBufferNametint::transform::__anon572942530111::State246   Symbol GetStructBufferName() {
247     if (!struct_buffer_name.IsValid()) {
248       static const char kStructBufferName[] = "tint_vertex_data";
249       struct_buffer_name = ctx.dst->Symbols().New(kStructBufferName);
250     }
251     return struct_buffer_name;
252   }
253 
254   /// Adds storage buffer decorated variables for the vertex buffers
AddVertexStorageBufferstint::transform::__anon572942530111::State255   void AddVertexStorageBuffers() {
256     // Creating the struct type
257     static const char kStructName[] = "TintVertexData";
258     auto* struct_type = ctx.dst->Structure(
259         ctx.dst->Symbols().New(kStructName),
260         {
261             ctx.dst->Member(GetStructBufferName(),
262                             ctx.dst->ty.array<ProgramBuilder::u32>(4)),
263         },
264         {
265             ctx.dst->create<ast::StructBlockDecoration>(),
266         });
267     for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
268       // The decorated variable with struct type
269       ctx.dst->Global(
270           GetVertexBufferName(i), ctx.dst->ty.Of(struct_type),
271           ast::StorageClass::kStorage, ast::Access::kRead,
272           ast::DecorationList{
273               ctx.dst->create<ast::BindingDecoration>(i),
274               ctx.dst->create<ast::GroupDecoration>(cfg.pulling_group),
275           });
276     }
277   }
278 
279   /// Creates and returns the assignment to the variables from the buffers
CreateVertexPullingPreambletint::transform::__anon572942530111::State280   ast::BlockStatement* CreateVertexPullingPreamble() {
281     // Assign by looking at the vertex descriptor to find attributes with
282     // matching location.
283 
284     ast::StatementList stmts;
285 
286     for (uint32_t buffer_idx = 0; buffer_idx < cfg.vertex_state.size();
287          ++buffer_idx) {
288       const VertexBufferLayoutDescriptor& buffer_layout =
289           cfg.vertex_state[buffer_idx];
290 
291       if ((buffer_layout.array_stride & 3) != 0) {
292         ctx.dst->Diagnostics().add_error(
293             diag::System::Transform,
294             "WebGPU requires that vertex stride must be a multiple of 4 bytes, "
295             "but VertexPulling array stride for buffer " +
296                 std::to_string(buffer_idx) + " was " +
297                 std::to_string(buffer_layout.array_stride) + " bytes");
298         return nullptr;
299       }
300 
301       auto* index_expr = buffer_layout.step_mode == VertexStepMode::kVertex
302                              ? vertex_index_expr()
303                              : instance_index_expr();
304 
305       // buffer_array_base is the base array offset for all the vertex
306       // attributes. These are units of uint (4 bytes).
307       auto buffer_array_base = ctx.dst->Symbols().New(
308           "buffer_array_base_" + std::to_string(buffer_idx));
309 
310       auto* attribute_offset = index_expr;
311       if (buffer_layout.array_stride != 4) {
312         attribute_offset =
313             ctx.dst->Mul(index_expr, buffer_layout.array_stride / 4u);
314       }
315 
316       // let pulling_offset_n = <attribute_offset>
317       stmts.emplace_back(ctx.dst->Decl(
318           ctx.dst->Const(buffer_array_base, nullptr, attribute_offset)));
319 
320       for (const VertexAttributeDescriptor& attribute_desc :
321            buffer_layout.attributes) {
322         auto it = location_info.find(attribute_desc.shader_location);
323         if (it == location_info.end()) {
324           continue;
325         }
326         auto& var = it->second;
327 
328         // Data type of the target WGSL variable
329         auto var_dt = DataTypeOf(var.type);
330         // Data type of the vertex stream attribute
331         auto fmt_dt = DataTypeOf(attribute_desc.format);
332 
333         // Base types must match between the vertex stream and the WGSL variable
334         if (var_dt.base_type != fmt_dt.base_type) {
335           std::stringstream err;
336           err << "VertexAttributeDescriptor for location "
337               << std::to_string(attribute_desc.shader_location)
338               << " has format " << attribute_desc.format
339               << " but shader expects "
340               << var.type->FriendlyName(ctx.src->Symbols());
341           ctx.dst->Diagnostics().add_error(diag::System::Transform, err.str());
342           return nullptr;
343         }
344 
345         // Load the attribute value
346         auto* fetch = Fetch(buffer_array_base, attribute_desc.offset,
347                             buffer_idx, attribute_desc.format);
348 
349         // The attribute value may not be of the desired vector width. If it is
350         // not, we'll need to either reduce the width with a swizzle, or append
351         // 0's and / or a 1.
352         auto* value = fetch;
353         if (var_dt.width < fmt_dt.width) {
354           // WGSL variable vector width is smaller than the loaded vector width
355           switch (var_dt.width) {
356             case 1:
357               value = ctx.dst->MemberAccessor(fetch, "x");
358               break;
359             case 2:
360               value = ctx.dst->MemberAccessor(fetch, "xy");
361               break;
362             case 3:
363               value = ctx.dst->MemberAccessor(fetch, "xyz");
364               break;
365             default:
366               TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
367                   << var_dt.width;
368               return nullptr;
369           }
370         } else if (var_dt.width > fmt_dt.width) {
371           // WGSL variable vector width is wider than the loaded vector width
372           const ast::Type* ty = nullptr;
373           ast::ExpressionList values{fetch};
374           switch (var_dt.base_type) {
375             case BaseType::kI32:
376               ty = ctx.dst->ty.i32();
377               for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
378                 values.emplace_back(ctx.dst->Expr((i == 3) ? 1 : 0));
379               }
380               break;
381             case BaseType::kU32:
382               ty = ctx.dst->ty.u32();
383               for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
384                 values.emplace_back(ctx.dst->Expr((i == 3) ? 1u : 0u));
385               }
386               break;
387             case BaseType::kF32:
388               ty = ctx.dst->ty.f32();
389               for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
390                 values.emplace_back(ctx.dst->Expr((i == 3) ? 1.f : 0.f));
391               }
392               break;
393             default:
394               TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
395                   << var_dt.base_type;
396               return nullptr;
397           }
398           value = ctx.dst->Construct(ctx.dst->ty.vec(ty, var_dt.width), values);
399         }
400 
401         // Assign the value to the WGSL variable
402         stmts.emplace_back(ctx.dst->Assign(var.expr(), value));
403       }
404     }
405 
406     if (stmts.empty()) {
407       return nullptr;
408     }
409 
410     return ctx.dst->create<ast::BlockStatement>(stmts);
411   }
412 
413   /// Generates an expression reading from a buffer a specific format.
414   /// @param array_base the symbol of the variable holding the base array offset
415   /// of the vertex array (each index is 4-bytes).
416   /// @param offset the byte offset of the data from `buffer_base`
417   /// @param buffer the index of the vertex buffer
418   /// @param format the format to read
Fetchtint::transform::__anon572942530111::State419   const ast::Expression* Fetch(Symbol array_base,
420                                uint32_t offset,
421                                uint32_t buffer,
422                                VertexFormat format) {
423     using u32 = ProgramBuilder::u32;
424     using i32 = ProgramBuilder::i32;
425     using f32 = ProgramBuilder::f32;
426 
427     // Returns a u32 loaded from buffer_base + offset.
428     auto load_u32 = [&] {
429       return LoadPrimitive(array_base, offset, buffer, VertexFormat::kUint32);
430     };
431 
432     // Returns a i32 loaded from buffer_base + offset.
433     auto load_i32 = [&] { return ctx.dst->Bitcast<i32>(load_u32()); };
434 
435     // Returns a u32 loaded from buffer_base + offset + 4.
436     auto load_next_u32 = [&] {
437       return LoadPrimitive(array_base, offset + 4, buffer,
438                            VertexFormat::kUint32);
439     };
440 
441     // Returns a i32 loaded from buffer_base + offset + 4.
442     auto load_next_i32 = [&] { return ctx.dst->Bitcast<i32>(load_next_u32()); };
443 
444     // Returns a u16 loaded from offset, packed in the high 16 bits of a u32.
445     // The low 16 bits are 0.
446     // `min_alignment` must be a power of two.
447     // `offset` must be `min_alignment` bytes aligned.
448     auto load_u16_h = [&] {
449       auto low_u32_offset = offset & ~3u;
450       auto* low_u32 = LoadPrimitive(array_base, low_u32_offset, buffer,
451                                     VertexFormat::kUint32);
452       switch (offset & 3) {
453         case 0:
454           return ctx.dst->Shl(low_u32, 16u);
455         case 1:
456           return ctx.dst->And(ctx.dst->Shl(low_u32, 8u), 0xffff0000u);
457         case 2:
458           return ctx.dst->And(low_u32, 0xffff0000u);
459         default: {  // 3:
460           auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
461                                          VertexFormat::kUint32);
462           auto* shr = ctx.dst->Shr(low_u32, 8u);
463           auto* shl = ctx.dst->Shl(high_u32, 24u);
464           return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff0000u);
465         }
466       }
467     };
468 
469     // Returns a u16 loaded from offset, packed in the low 16 bits of a u32.
470     // The high 16 bits are 0.
471     auto load_u16_l = [&] {
472       auto low_u32_offset = offset & ~3u;
473       auto* low_u32 = LoadPrimitive(array_base, low_u32_offset, buffer,
474                                     VertexFormat::kUint32);
475       switch (offset & 3) {
476         case 0:
477           return ctx.dst->And(low_u32, 0xffffu);
478         case 1:
479           return ctx.dst->And(ctx.dst->Shr(low_u32, 8u), 0xffffu);
480         case 2:
481           return ctx.dst->Shr(low_u32, 16u);
482         default: {  // 3:
483           auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
484                                          VertexFormat::kUint32);
485           auto* shr = ctx.dst->Shr(low_u32, 24u);
486           auto* shl = ctx.dst->Shl(high_u32, 8u);
487           return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffffu);
488         }
489       }
490     };
491 
492     // Returns a i16 loaded from offset, packed in the high 16 bits of a u32.
493     // The low 16 bits are 0.
494     auto load_i16_h = [&] { return ctx.dst->Bitcast<i32>(load_u16_h()); };
495 
496     // Assumptions are made that alignment must be at least as large as the size
497     // of a single component.
498     switch (format) {
499       // Basic primitives
500       case VertexFormat::kUint32:
501       case VertexFormat::kSint32:
502       case VertexFormat::kFloat32:
503         return LoadPrimitive(array_base, offset, buffer, format);
504 
505         // Vectors of basic primitives
506       case VertexFormat::kUint32x2:
507         return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
508                        VertexFormat::kUint32, 2);
509       case VertexFormat::kUint32x3:
510         return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
511                        VertexFormat::kUint32, 3);
512       case VertexFormat::kUint32x4:
513         return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
514                        VertexFormat::kUint32, 4);
515       case VertexFormat::kSint32x2:
516         return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
517                        VertexFormat::kSint32, 2);
518       case VertexFormat::kSint32x3:
519         return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
520                        VertexFormat::kSint32, 3);
521       case VertexFormat::kSint32x4:
522         return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
523                        VertexFormat::kSint32, 4);
524       case VertexFormat::kFloat32x2:
525         return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
526                        VertexFormat::kFloat32, 2);
527       case VertexFormat::kFloat32x3:
528         return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
529                        VertexFormat::kFloat32, 3);
530       case VertexFormat::kFloat32x4:
531         return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
532                        VertexFormat::kFloat32, 4);
533 
534       case VertexFormat::kUint8x2: {
535         // yyxx0000, yyxx0000
536         auto* u16s = ctx.dst->vec2<u32>(load_u16_h());
537         // xx000000, yyxx0000
538         auto* shl = ctx.dst->Shl(u16s, ctx.dst->vec2<u32>(8u, 0u));
539         // 000000xx, 000000yy
540         return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24u));
541       }
542       case VertexFormat::kUint8x4: {
543         // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
544         auto* u32s = ctx.dst->vec4<u32>(load_u32());
545         // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
546         auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec4<u32>(24u, 16u, 8u, 0u));
547         // 000000xx, 000000yy, 000000zz, 000000ww
548         return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24u));
549       }
550       case VertexFormat::kUint16x2: {
551         // yyyyxxxx, yyyyxxxx
552         auto* u32s = ctx.dst->vec2<u32>(load_u32());
553         // xxxx0000, yyyyxxxx
554         auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec2<u32>(16u, 0u));
555         // 0000xxxx, 0000yyyy
556         return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16u));
557       }
558       case VertexFormat::kUint16x4: {
559         // yyyyxxxx, wwwwzzzz
560         auto* u32s = ctx.dst->vec2<u32>(load_u32(), load_next_u32());
561         // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
562         auto* xxyy = ctx.dst->MemberAccessor(u32s, "xxyy");
563         // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
564         auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16u, 0u, 16u, 0u));
565         // 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww
566         return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16u));
567       }
568       case VertexFormat::kSint8x2: {
569         // yyxx0000, yyxx0000
570         auto* i16s = ctx.dst->vec2<i32>(load_i16_h());
571         // xx000000, yyxx0000
572         auto* shl = ctx.dst->Shl(i16s, ctx.dst->vec2<u32>(8u, 0u));
573         // ssssssxx, ssssssyy
574         return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24u));
575       }
576       case VertexFormat::kSint8x4: {
577         // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
578         auto* i32s = ctx.dst->vec4<i32>(load_i32());
579         // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
580         auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec4<u32>(24u, 16u, 8u, 0u));
581         // ssssssxx, ssssssyy, sssssszz, ssssssww
582         return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24u));
583       }
584       case VertexFormat::kSint16x2: {
585         // yyyyxxxx, yyyyxxxx
586         auto* i32s = ctx.dst->vec2<i32>(load_i32());
587         // xxxx0000, yyyyxxxx
588         auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec2<u32>(16u, 0u));
589         // ssssxxxx, ssssyyyy
590         return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16u));
591       }
592       case VertexFormat::kSint16x4: {
593         // yyyyxxxx, wwwwzzzz
594         auto* i32s = ctx.dst->vec2<i32>(load_i32(), load_next_i32());
595         // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
596         auto* xxyy = ctx.dst->MemberAccessor(i32s, "xxyy");
597         // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
598         auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16u, 0u, 16u, 0u));
599         // ssssxxxx, ssssyyyy, sssszzzz, sssswwww
600         return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16u));
601       }
602       case VertexFormat::kUnorm8x2:
603         return ctx.dst->MemberAccessor(
604             ctx.dst->Call("unpack4x8unorm", load_u16_l()), "xy");
605       case VertexFormat::kSnorm8x2:
606         return ctx.dst->MemberAccessor(
607             ctx.dst->Call("unpack4x8snorm", load_u16_l()), "xy");
608       case VertexFormat::kUnorm8x4:
609         return ctx.dst->Call("unpack4x8unorm", load_u32());
610       case VertexFormat::kSnorm8x4:
611         return ctx.dst->Call("unpack4x8snorm", load_u32());
612       case VertexFormat::kUnorm16x2:
613         return ctx.dst->Call("unpack2x16unorm", load_u32());
614       case VertexFormat::kSnorm16x2:
615         return ctx.dst->Call("unpack2x16snorm", load_u32());
616       case VertexFormat::kFloat16x2:
617         return ctx.dst->Call("unpack2x16float", load_u32());
618       case VertexFormat::kUnorm16x4:
619         return ctx.dst->vec4<f32>(
620             ctx.dst->Call("unpack2x16unorm", load_u32()),
621             ctx.dst->Call("unpack2x16unorm", load_next_u32()));
622       case VertexFormat::kSnorm16x4:
623         return ctx.dst->vec4<f32>(
624             ctx.dst->Call("unpack2x16snorm", load_u32()),
625             ctx.dst->Call("unpack2x16snorm", load_next_u32()));
626       case VertexFormat::kFloat16x4:
627         return ctx.dst->vec4<f32>(
628             ctx.dst->Call("unpack2x16float", load_u32()),
629             ctx.dst->Call("unpack2x16float", load_next_u32()));
630     }
631 
632     TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
633         << "format " << static_cast<int>(format);
634     return nullptr;
635   }
636 
637   /// Generates an expression reading an aligned basic type (u32, i32, f32) from
638   /// a vertex buffer.
639   /// @param array_base the symbol of the variable holding the base array offset
640   /// of the vertex array (each index is 4-bytes).
641   /// @param offset the byte offset of the data from `buffer_base`
642   /// @param buffer the index of the vertex buffer
643   /// @param format VertexFormat::kUint32, VertexFormat::kSint32 or
644   /// VertexFormat::kFloat32
LoadPrimitivetint::transform::__anon572942530111::State645   const ast::Expression* LoadPrimitive(Symbol array_base,
646                                        uint32_t offset,
647                                        uint32_t buffer,
648                                        VertexFormat format) {
649     const ast::Expression* u32 = nullptr;
650     if ((offset & 3) == 0) {
651       // Aligned load.
652 
653       const ast ::Expression* index = nullptr;
654       if (offset > 0) {
655         index = ctx.dst->Add(array_base, offset / 4);
656       } else {
657         index = ctx.dst->Expr(array_base);
658       }
659       u32 = ctx.dst->IndexAccessor(
660           ctx.dst->MemberAccessor(GetVertexBufferName(buffer),
661                                   GetStructBufferName()),
662           index);
663 
664     } else {
665       // Unaligned load
666       uint32_t offset_aligned = offset & ~3u;
667       auto* low = LoadPrimitive(array_base, offset_aligned, buffer,
668                                 VertexFormat::kUint32);
669       auto* high = LoadPrimitive(array_base, offset_aligned + 4u, buffer,
670                                  VertexFormat::kUint32);
671 
672       uint32_t shift = 8u * (offset & 3u);
673 
674       auto* low_shr = ctx.dst->Shr(low, shift);
675       auto* high_shl = ctx.dst->Shl(high, 32u - shift);
676       u32 = ctx.dst->Or(low_shr, high_shl);
677     }
678 
679     switch (format) {
680       case VertexFormat::kUint32:
681         return u32;
682       case VertexFormat::kSint32:
683         return ctx.dst->Bitcast(ctx.dst->ty.i32(), u32);
684       case VertexFormat::kFloat32:
685         return ctx.dst->Bitcast(ctx.dst->ty.f32(), u32);
686       default:
687         break;
688     }
689     TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
690         << "invalid format for LoadPrimitive" << static_cast<int>(format);
691     return nullptr;
692   }
693 
694   /// Generates an expression reading a vec2/3/4 from a vertex buffer.
695   /// @param array_base the symbol of the variable holding the base array offset
696   /// of the vertex array (each index is 4-bytes).
697   /// @param offset the byte offset of the data from `buffer_base`
698   /// @param buffer the index of the vertex buffer
699   /// @param element_stride stride between elements, in bytes
700   /// @param base_type underlying AST type
701   /// @param base_format underlying vertex format
702   /// @param count how many elements the vector has
LoadVectint::transform::__anon572942530111::State703   const ast::Expression* LoadVec(Symbol array_base,
704                                  uint32_t offset,
705                                  uint32_t buffer,
706                                  uint32_t element_stride,
707                                  const ast::Type* base_type,
708                                  VertexFormat base_format,
709                                  uint32_t count) {
710     ast::ExpressionList expr_list;
711     for (uint32_t i = 0; i < count; ++i) {
712       // Offset read position by element_stride for each component
713       uint32_t primitive_offset = offset + element_stride * i;
714       expr_list.push_back(
715           LoadPrimitive(array_base, primitive_offset, buffer, base_format));
716     }
717 
718     return ctx.dst->Construct(ctx.dst->create<ast::Vector>(base_type, count),
719                               std::move(expr_list));
720   }
721 
722   /// Process a non-struct entry point parameter.
723   /// Generate function-scope variables for location parameters, and record
724   /// vertex_index and instance_index builtins if present.
725   /// @param func the entry point function
726   /// @param param the parameter to process
ProcessNonStructParametertint::transform::__anon572942530111::State727   void ProcessNonStructParameter(const ast::Function* func,
728                                  const ast::Variable* param) {
729     if (auto* location =
730             ast::GetDecoration<ast::LocationDecoration>(param->decorations)) {
731       // Create a function-scope variable to replace the parameter.
732       auto func_var_sym = ctx.Clone(param->symbol);
733       auto* func_var_type = ctx.Clone(param->type);
734       auto* func_var = ctx.dst->Var(func_var_sym, func_var_type);
735       ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
736       // Capture mapping from location to the new variable.
737       LocationInfo info;
738       info.expr = [this, func_var]() { return ctx.dst->Expr(func_var); };
739       info.type = ctx.src->Sem().Get(param)->Type();
740       location_info[location->value] = info;
741     } else if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
742                    param->decorations)) {
743       // Check for existing vertex_index and instance_index builtins.
744       if (builtin->builtin == ast::Builtin::kVertexIndex) {
745         vertex_index_expr = [this, param]() {
746           return ctx.dst->Expr(ctx.Clone(param->symbol));
747         };
748       } else if (builtin->builtin == ast::Builtin::kInstanceIndex) {
749         instance_index_expr = [this, param]() {
750           return ctx.dst->Expr(ctx.Clone(param->symbol));
751         };
752       }
753       new_function_parameters.push_back(ctx.Clone(param));
754     } else {
755       TINT_ICE(Transform, ctx.dst->Diagnostics())
756           << "Invalid entry point parameter";
757     }
758   }
759 
760   /// Process a struct entry point parameter.
761   /// If the struct has members with location attributes, push the parameter to
762   /// a function-scope variable and create a new struct parameter without those
763   /// attributes. Record expressions for members that are vertex_index and
764   /// instance_index builtins.
765   /// @param func the entry point function
766   /// @param param the parameter to process
767   /// @param struct_ty the structure type
ProcessStructParametertint::transform::__anon572942530111::State768   void ProcessStructParameter(const ast::Function* func,
769                               const ast::Variable* param,
770                               const ast::Struct* struct_ty) {
771     auto param_sym = ctx.Clone(param->symbol);
772 
773     // Process the struct members.
774     bool has_locations = false;
775     ast::StructMemberList members_to_clone;
776     for (auto* member : struct_ty->members) {
777       auto member_sym = ctx.Clone(member->symbol);
778       std::function<const ast::Expression*()> member_expr = [this, param_sym,
779                                                              member_sym]() {
780         return ctx.dst->MemberAccessor(param_sym, member_sym);
781       };
782 
783       if (auto* location = ast::GetDecoration<ast::LocationDecoration>(
784               member->decorations)) {
785         // Capture mapping from location to struct member.
786         LocationInfo info;
787         info.expr = member_expr;
788         info.type = ctx.src->Sem().Get(member)->Type();
789         location_info[location->value] = info;
790         has_locations = true;
791       } else if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
792                      member->decorations)) {
793         // Check for existing vertex_index and instance_index builtins.
794         if (builtin->builtin == ast::Builtin::kVertexIndex) {
795           vertex_index_expr = member_expr;
796         } else if (builtin->builtin == ast::Builtin::kInstanceIndex) {
797           instance_index_expr = member_expr;
798         }
799         members_to_clone.push_back(member);
800       } else {
801         TINT_ICE(Transform, ctx.dst->Diagnostics())
802             << "Invalid entry point parameter";
803       }
804     }
805 
806     if (!has_locations) {
807       // Nothing to do.
808       new_function_parameters.push_back(ctx.Clone(param));
809       return;
810     }
811 
812     // Create a function-scope variable to replace the parameter.
813     auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->type));
814     ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
815 
816     if (!members_to_clone.empty()) {
817       // Create a new struct without the location attributes.
818       ast::StructMemberList new_members;
819       for (auto* member : members_to_clone) {
820         auto member_sym = ctx.Clone(member->symbol);
821         auto* member_type = ctx.Clone(member->type);
822         auto member_decos = ctx.Clone(member->decorations);
823         new_members.push_back(
824             ctx.dst->Member(member_sym, member_type, std::move(member_decos)));
825       }
826       auto* new_struct = ctx.dst->Structure(ctx.dst->Sym(), new_members);
827 
828       // Create a new function parameter with this struct.
829       auto* new_param =
830           ctx.dst->Param(ctx.dst->Sym(), ctx.dst->ty.Of(new_struct));
831       new_function_parameters.push_back(new_param);
832 
833       // Copy values from the new parameter to the function-scope variable.
834       for (auto* member : members_to_clone) {
835         auto member_name = ctx.Clone(member->symbol);
836         ctx.InsertFront(
837             func->body->statements,
838             ctx.dst->Assign(ctx.dst->MemberAccessor(func_var, member_name),
839                             ctx.dst->MemberAccessor(new_param, member_name)));
840       }
841     }
842   }
843 
844   /// Process an entry point function.
845   /// @param func the entry point function
Processtint::transform::__anon572942530111::State846   void Process(const ast::Function* func) {
847     if (func->body->Empty()) {
848       return;
849     }
850 
851     // Process entry point parameters.
852     for (auto* param : func->params) {
853       auto* sem = ctx.src->Sem().Get(param);
854       if (auto* str = sem->Type()->As<sem::Struct>()) {
855         ProcessStructParameter(func, param, str->Declaration());
856       } else {
857         ProcessNonStructParameter(func, param);
858       }
859     }
860 
861     // Insert new parameters for vertex_index and instance_index if needed.
862     if (!vertex_index_expr) {
863       for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
864         if (layout.step_mode == VertexStepMode::kVertex) {
865           auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index");
866           new_function_parameters.push_back(
867               ctx.dst->Param(name, ctx.dst->ty.u32(),
868                              {ctx.dst->Builtin(ast::Builtin::kVertexIndex)}));
869           vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); };
870           break;
871         }
872       }
873     }
874     if (!instance_index_expr) {
875       for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
876         if (layout.step_mode == VertexStepMode::kInstance) {
877           auto name = ctx.dst->Symbols().New("tint_pulling_instance_index");
878           new_function_parameters.push_back(
879               ctx.dst->Param(name, ctx.dst->ty.u32(),
880                              {ctx.dst->Builtin(ast::Builtin::kInstanceIndex)}));
881           instance_index_expr = [this, name]() { return ctx.dst->Expr(name); };
882           break;
883         }
884       }
885     }
886 
887     // Generate vertex pulling preamble.
888     if (auto* block = CreateVertexPullingPreamble()) {
889       ctx.InsertFront(func->body->statements, block);
890     }
891 
892     // Rewrite the function header with the new parameters.
893     auto func_sym = ctx.Clone(func->symbol);
894     auto* ret_type = ctx.Clone(func->return_type);
895     auto* body = ctx.Clone(func->body);
896     auto decos = ctx.Clone(func->decorations);
897     auto ret_decos = ctx.Clone(func->return_type_decorations);
898     auto* new_func = ctx.dst->create<ast::Function>(
899         func->source, func_sym, new_function_parameters, ret_type, body,
900         std::move(decos), std::move(ret_decos));
901     ctx.Replace(func, new_func);
902   }
903 };
904 
905 }  // namespace
906 
// Out-of-line defaulted constructor and destructor of the VertexPulling
// transform.
907 VertexPulling::VertexPulling() = default;
908 VertexPulling::~VertexPulling() = default;
909 
Run(CloneContext & ctx,const DataMap & inputs,DataMap &)910 void VertexPulling::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) {
911   auto cfg = cfg_;
912   if (auto* cfg_data = inputs.Get<Config>()) {
913     cfg = *cfg_data;
914   }
915 
916   // Find entry point
917   auto* func = ctx.src->AST().Functions().Find(
918       ctx.src->Symbols().Get(cfg.entry_point_name),
919       ast::PipelineStage::kVertex);
920   if (func == nullptr) {
921     ctx.dst->Diagnostics().add_error(diag::System::Transform,
922                                      "Vertex stage entry point not found");
923     return;
924   }
925 
926   // TODO(idanr): Need to check shader locations in descriptor cover all
927   // attributes
928 
929   // TODO(idanr): Make sure we covered all error cases, to guarantee the
930   // following stages will pass
931 
932   State state{ctx, cfg};
933   state.AddVertexStorageBuffers();
934   state.Process(func);
935 
936   ctx.Clone();
937 }
938 
// Out-of-line defaulted special members of VertexPulling::Config: default
// constructor, copy constructor, destructor and copy assignment.
939 VertexPulling::Config::Config() = default;
940 VertexPulling::Config::Config(const Config&) = default;
941 VertexPulling::Config::~Config() = default;
942 VertexPulling::Config& VertexPulling::Config::operator=(const Config&) =
943     default;
944 
// Defaulted default constructor.
945 VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
946 
// Constructs a descriptor from its three fields. `in_attributes` is taken by
// value and moved into the `attributes` member.
VertexBufferLayoutDescriptor(uint32_t in_array_stride,VertexStepMode in_step_mode,std::vector<VertexAttributeDescriptor> in_attributes)947 VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
948     uint32_t in_array_stride,
949     VertexStepMode in_step_mode,
950     std::vector<VertexAttributeDescriptor> in_attributes)
951     : array_stride(in_array_stride),
952       step_mode(in_step_mode),
953       attributes(std::move(in_attributes)) {}
954 
// Defaulted copy constructor.
955 VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
956     const VertexBufferLayoutDescriptor& other) = default;
957 
// Defaulted copy assignment.
958 VertexBufferLayoutDescriptor& VertexBufferLayoutDescriptor::operator=(
959     const VertexBufferLayoutDescriptor& other) = default;
960 
// Defaulted destructor.
961 VertexBufferLayoutDescriptor::~VertexBufferLayoutDescriptor() = default;
962 
963 }  // namespace transform
964 }  // namespace tint
965