1 // Copyright 2020 The Tint Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/transform/vertex_pulling.h"
16
17 #include <algorithm>
18 #include <utility>
19
20 #include "src/ast/assignment_statement.h"
21 #include "src/ast/bitcast_expression.h"
22 #include "src/ast/struct_block_decoration.h"
23 #include "src/ast/variable_decl_statement.h"
24 #include "src/program_builder.h"
25 #include "src/sem/variable.h"
26 #include "src/utils/map.h"
27 #include "src/utils/math.h"
28
29 TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling);
30 TINT_INSTANTIATE_TYPEINFO(tint::transform::VertexPulling::Config);
31
32 namespace tint {
33 namespace transform {
34
35 namespace {
36
/// The scalar component category of a vertex attribute.
/// A format is either a scalar of this category, or a vector whose elements
/// are of this category.
enum class BaseType {
  /// Not one of the supported scalar or vector-of-scalar types
  kInvalid,
  /// Unsigned 32-bit integer components
  kU32,
  /// Signed 32-bit integer components
  kI32,
  /// 32-bit floating point components
  kF32,
};
45
46 /// Writes the BaseType to the std::ostream.
47 /// @param out the std::ostream to write to
48 /// @param format the BaseType to write
49 /// @returns out so calls can be chained
operator <<(std::ostream & out,BaseType format)50 std::ostream& operator<<(std::ostream& out, BaseType format) {
51 switch (format) {
52 case BaseType::kInvalid:
53 return out << "invalid";
54 case BaseType::kU32:
55 return out << "u32";
56 case BaseType::kI32:
57 return out << "i32";
58 case BaseType::kF32:
59 return out << "f32";
60 }
61 return out << "<unknown>";
62 }
63
64 /// Writes the VertexFormat to the std::ostream.
65 /// @param out the std::ostream to write to
66 /// @param format the VertexFormat to write
67 /// @returns out so calls can be chained
operator <<(std::ostream & out,VertexFormat format)68 std::ostream& operator<<(std::ostream& out, VertexFormat format) {
69 switch (format) {
70 case VertexFormat::kUint8x2:
71 return out << "uint8x2";
72 case VertexFormat::kUint8x4:
73 return out << "uint8x4";
74 case VertexFormat::kSint8x2:
75 return out << "sint8x2";
76 case VertexFormat::kSint8x4:
77 return out << "sint8x4";
78 case VertexFormat::kUnorm8x2:
79 return out << "unorm8x2";
80 case VertexFormat::kUnorm8x4:
81 return out << "unorm8x4";
82 case VertexFormat::kSnorm8x2:
83 return out << "snorm8x2";
84 case VertexFormat::kSnorm8x4:
85 return out << "snorm8x4";
86 case VertexFormat::kUint16x2:
87 return out << "uint16x2";
88 case VertexFormat::kUint16x4:
89 return out << "uint16x4";
90 case VertexFormat::kSint16x2:
91 return out << "sint16x2";
92 case VertexFormat::kSint16x4:
93 return out << "sint16x4";
94 case VertexFormat::kUnorm16x2:
95 return out << "unorm16x2";
96 case VertexFormat::kUnorm16x4:
97 return out << "unorm16x4";
98 case VertexFormat::kSnorm16x2:
99 return out << "snorm16x2";
100 case VertexFormat::kSnorm16x4:
101 return out << "snorm16x4";
102 case VertexFormat::kFloat16x2:
103 return out << "float16x2";
104 case VertexFormat::kFloat16x4:
105 return out << "float16x4";
106 case VertexFormat::kFloat32:
107 return out << "float32";
108 case VertexFormat::kFloat32x2:
109 return out << "float32x2";
110 case VertexFormat::kFloat32x3:
111 return out << "float32x3";
112 case VertexFormat::kFloat32x4:
113 return out << "float32x4";
114 case VertexFormat::kUint32:
115 return out << "uint32";
116 case VertexFormat::kUint32x2:
117 return out << "uint32x2";
118 case VertexFormat::kUint32x3:
119 return out << "uint32x3";
120 case VertexFormat::kUint32x4:
121 return out << "uint32x4";
122 case VertexFormat::kSint32:
123 return out << "sint32";
124 case VertexFormat::kSint32x2:
125 return out << "sint32x2";
126 case VertexFormat::kSint32x3:
127 return out << "sint32x3";
128 case VertexFormat::kSint32x4:
129 return out << "sint32x4";
130 }
131 return out << "<unknown>";
132 }
133
134 /// A vertex attribute data format.
135 struct DataType {
136 BaseType base_type;
137 uint32_t width; // 1 for scalar, 2+ for a vector
138 };
139
DataTypeOf(const sem::Type * ty)140 DataType DataTypeOf(const sem::Type* ty) {
141 if (ty->Is<sem::I32>()) {
142 return {BaseType::kI32, 1};
143 }
144 if (ty->Is<sem::U32>()) {
145 return {BaseType::kU32, 1};
146 }
147 if (ty->Is<sem::F32>()) {
148 return {BaseType::kF32, 1};
149 }
150 if (auto* vec = ty->As<sem::Vector>()) {
151 return {DataTypeOf(vec->type()).base_type, vec->Width()};
152 }
153 return {BaseType::kInvalid, 0};
154 }
155
DataTypeOf(VertexFormat format)156 DataType DataTypeOf(VertexFormat format) {
157 switch (format) {
158 case VertexFormat::kUint32:
159 return {BaseType::kU32, 1};
160 case VertexFormat::kUint8x2:
161 case VertexFormat::kUint16x2:
162 case VertexFormat::kUint32x2:
163 return {BaseType::kU32, 2};
164 case VertexFormat::kUint32x3:
165 return {BaseType::kU32, 3};
166 case VertexFormat::kUint8x4:
167 case VertexFormat::kUint16x4:
168 case VertexFormat::kUint32x4:
169 return {BaseType::kU32, 4};
170 case VertexFormat::kSint32:
171 return {BaseType::kI32, 1};
172 case VertexFormat::kSint8x2:
173 case VertexFormat::kSint16x2:
174 case VertexFormat::kSint32x2:
175 return {BaseType::kI32, 2};
176 case VertexFormat::kSint32x3:
177 return {BaseType::kI32, 3};
178 case VertexFormat::kSint8x4:
179 case VertexFormat::kSint16x4:
180 case VertexFormat::kSint32x4:
181 return {BaseType::kI32, 4};
182 case VertexFormat::kFloat32:
183 return {BaseType::kF32, 1};
184 case VertexFormat::kUnorm8x2:
185 case VertexFormat::kSnorm8x2:
186 case VertexFormat::kUnorm16x2:
187 case VertexFormat::kSnorm16x2:
188 case VertexFormat::kFloat16x2:
189 case VertexFormat::kFloat32x2:
190 return {BaseType::kF32, 2};
191 case VertexFormat::kFloat32x3:
192 return {BaseType::kF32, 3};
193 case VertexFormat::kUnorm8x4:
194 case VertexFormat::kSnorm8x4:
195 case VertexFormat::kUnorm16x4:
196 case VertexFormat::kSnorm16x4:
197 case VertexFormat::kFloat16x4:
198 case VertexFormat::kFloat32x4:
199 return {BaseType::kF32, 4};
200 }
201 return {BaseType::kInvalid, 0};
202 }
203
204 struct State {
Statetint::transform::__anon572942530111::State205 State(CloneContext& context, const VertexPulling::Config& c)
206 : ctx(context), cfg(c) {}
207 State(const State&) = default;
208 ~State() = default;
209
210 /// LocationReplacement describes an ast::Variable replacement for a
211 /// location input.
212 struct LocationReplacement {
213 /// The variable to replace in the source Program
214 ast::Variable* from;
215 /// The replacement to use in the target ProgramBuilder
216 ast::Variable* to;
217 };
218
219 struct LocationInfo {
220 std::function<const ast::Expression*()> expr;
221 const sem::Type* type;
222 };
223
224 CloneContext& ctx;
225 VertexPulling::Config const cfg;
226 std::unordered_map<uint32_t, LocationInfo> location_info;
227 std::function<const ast::Expression*()> vertex_index_expr = nullptr;
228 std::function<const ast::Expression*()> instance_index_expr = nullptr;
229 Symbol pulling_position_name;
230 Symbol struct_buffer_name;
231 std::unordered_map<uint32_t, Symbol> vertex_buffer_names;
232 ast::VariableList new_function_parameters;
233
234 /// Generate the vertex buffer binding name
235 /// @param index index to append to buffer name
GetVertexBufferNametint::transform::__anon572942530111::State236 Symbol GetVertexBufferName(uint32_t index) {
237 return utils::GetOrCreate(vertex_buffer_names, index, [&] {
238 static const char kVertexBufferNamePrefix[] =
239 "tint_pulling_vertex_buffer_";
240 return ctx.dst->Symbols().New(kVertexBufferNamePrefix +
241 std::to_string(index));
242 });
243 }
244
245 /// Lazily generates the structure buffer symbol
GetStructBufferNametint::transform::__anon572942530111::State246 Symbol GetStructBufferName() {
247 if (!struct_buffer_name.IsValid()) {
248 static const char kStructBufferName[] = "tint_vertex_data";
249 struct_buffer_name = ctx.dst->Symbols().New(kStructBufferName);
250 }
251 return struct_buffer_name;
252 }
253
254 /// Adds storage buffer decorated variables for the vertex buffers
AddVertexStorageBufferstint::transform::__anon572942530111::State255 void AddVertexStorageBuffers() {
256 // Creating the struct type
257 static const char kStructName[] = "TintVertexData";
258 auto* struct_type = ctx.dst->Structure(
259 ctx.dst->Symbols().New(kStructName),
260 {
261 ctx.dst->Member(GetStructBufferName(),
262 ctx.dst->ty.array<ProgramBuilder::u32>(4)),
263 },
264 {
265 ctx.dst->create<ast::StructBlockDecoration>(),
266 });
267 for (uint32_t i = 0; i < cfg.vertex_state.size(); ++i) {
268 // The decorated variable with struct type
269 ctx.dst->Global(
270 GetVertexBufferName(i), ctx.dst->ty.Of(struct_type),
271 ast::StorageClass::kStorage, ast::Access::kRead,
272 ast::DecorationList{
273 ctx.dst->create<ast::BindingDecoration>(i),
274 ctx.dst->create<ast::GroupDecoration>(cfg.pulling_group),
275 });
276 }
277 }
278
279 /// Creates and returns the assignment to the variables from the buffers
CreateVertexPullingPreambletint::transform::__anon572942530111::State280 ast::BlockStatement* CreateVertexPullingPreamble() {
281 // Assign by looking at the vertex descriptor to find attributes with
282 // matching location.
283
284 ast::StatementList stmts;
285
286 for (uint32_t buffer_idx = 0; buffer_idx < cfg.vertex_state.size();
287 ++buffer_idx) {
288 const VertexBufferLayoutDescriptor& buffer_layout =
289 cfg.vertex_state[buffer_idx];
290
291 if ((buffer_layout.array_stride & 3) != 0) {
292 ctx.dst->Diagnostics().add_error(
293 diag::System::Transform,
294 "WebGPU requires that vertex stride must be a multiple of 4 bytes, "
295 "but VertexPulling array stride for buffer " +
296 std::to_string(buffer_idx) + " was " +
297 std::to_string(buffer_layout.array_stride) + " bytes");
298 return nullptr;
299 }
300
301 auto* index_expr = buffer_layout.step_mode == VertexStepMode::kVertex
302 ? vertex_index_expr()
303 : instance_index_expr();
304
305 // buffer_array_base is the base array offset for all the vertex
306 // attributes. These are units of uint (4 bytes).
307 auto buffer_array_base = ctx.dst->Symbols().New(
308 "buffer_array_base_" + std::to_string(buffer_idx));
309
310 auto* attribute_offset = index_expr;
311 if (buffer_layout.array_stride != 4) {
312 attribute_offset =
313 ctx.dst->Mul(index_expr, buffer_layout.array_stride / 4u);
314 }
315
316 // let pulling_offset_n = <attribute_offset>
317 stmts.emplace_back(ctx.dst->Decl(
318 ctx.dst->Const(buffer_array_base, nullptr, attribute_offset)));
319
320 for (const VertexAttributeDescriptor& attribute_desc :
321 buffer_layout.attributes) {
322 auto it = location_info.find(attribute_desc.shader_location);
323 if (it == location_info.end()) {
324 continue;
325 }
326 auto& var = it->second;
327
328 // Data type of the target WGSL variable
329 auto var_dt = DataTypeOf(var.type);
330 // Data type of the vertex stream attribute
331 auto fmt_dt = DataTypeOf(attribute_desc.format);
332
333 // Base types must match between the vertex stream and the WGSL variable
334 if (var_dt.base_type != fmt_dt.base_type) {
335 std::stringstream err;
336 err << "VertexAttributeDescriptor for location "
337 << std::to_string(attribute_desc.shader_location)
338 << " has format " << attribute_desc.format
339 << " but shader expects "
340 << var.type->FriendlyName(ctx.src->Symbols());
341 ctx.dst->Diagnostics().add_error(diag::System::Transform, err.str());
342 return nullptr;
343 }
344
345 // Load the attribute value
346 auto* fetch = Fetch(buffer_array_base, attribute_desc.offset,
347 buffer_idx, attribute_desc.format);
348
349 // The attribute value may not be of the desired vector width. If it is
350 // not, we'll need to either reduce the width with a swizzle, or append
351 // 0's and / or a 1.
352 auto* value = fetch;
353 if (var_dt.width < fmt_dt.width) {
354 // WGSL variable vector width is smaller than the loaded vector width
355 switch (var_dt.width) {
356 case 1:
357 value = ctx.dst->MemberAccessor(fetch, "x");
358 break;
359 case 2:
360 value = ctx.dst->MemberAccessor(fetch, "xy");
361 break;
362 case 3:
363 value = ctx.dst->MemberAccessor(fetch, "xyz");
364 break;
365 default:
366 TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
367 << var_dt.width;
368 return nullptr;
369 }
370 } else if (var_dt.width > fmt_dt.width) {
371 // WGSL variable vector width is wider than the loaded vector width
372 const ast::Type* ty = nullptr;
373 ast::ExpressionList values{fetch};
374 switch (var_dt.base_type) {
375 case BaseType::kI32:
376 ty = ctx.dst->ty.i32();
377 for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
378 values.emplace_back(ctx.dst->Expr((i == 3) ? 1 : 0));
379 }
380 break;
381 case BaseType::kU32:
382 ty = ctx.dst->ty.u32();
383 for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
384 values.emplace_back(ctx.dst->Expr((i == 3) ? 1u : 0u));
385 }
386 break;
387 case BaseType::kF32:
388 ty = ctx.dst->ty.f32();
389 for (uint32_t i = fmt_dt.width; i < var_dt.width; i++) {
390 values.emplace_back(ctx.dst->Expr((i == 3) ? 1.f : 0.f));
391 }
392 break;
393 default:
394 TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
395 << var_dt.base_type;
396 return nullptr;
397 }
398 value = ctx.dst->Construct(ctx.dst->ty.vec(ty, var_dt.width), values);
399 }
400
401 // Assign the value to the WGSL variable
402 stmts.emplace_back(ctx.dst->Assign(var.expr(), value));
403 }
404 }
405
406 if (stmts.empty()) {
407 return nullptr;
408 }
409
410 return ctx.dst->create<ast::BlockStatement>(stmts);
411 }
412
413 /// Generates an expression reading from a buffer a specific format.
414 /// @param array_base the symbol of the variable holding the base array offset
415 /// of the vertex array (each index is 4-bytes).
416 /// @param offset the byte offset of the data from `buffer_base`
417 /// @param buffer the index of the vertex buffer
418 /// @param format the format to read
Fetchtint::transform::__anon572942530111::State419 const ast::Expression* Fetch(Symbol array_base,
420 uint32_t offset,
421 uint32_t buffer,
422 VertexFormat format) {
423 using u32 = ProgramBuilder::u32;
424 using i32 = ProgramBuilder::i32;
425 using f32 = ProgramBuilder::f32;
426
427 // Returns a u32 loaded from buffer_base + offset.
428 auto load_u32 = [&] {
429 return LoadPrimitive(array_base, offset, buffer, VertexFormat::kUint32);
430 };
431
432 // Returns a i32 loaded from buffer_base + offset.
433 auto load_i32 = [&] { return ctx.dst->Bitcast<i32>(load_u32()); };
434
435 // Returns a u32 loaded from buffer_base + offset + 4.
436 auto load_next_u32 = [&] {
437 return LoadPrimitive(array_base, offset + 4, buffer,
438 VertexFormat::kUint32);
439 };
440
441 // Returns a i32 loaded from buffer_base + offset + 4.
442 auto load_next_i32 = [&] { return ctx.dst->Bitcast<i32>(load_next_u32()); };
443
444 // Returns a u16 loaded from offset, packed in the high 16 bits of a u32.
445 // The low 16 bits are 0.
446 // `min_alignment` must be a power of two.
447 // `offset` must be `min_alignment` bytes aligned.
448 auto load_u16_h = [&] {
449 auto low_u32_offset = offset & ~3u;
450 auto* low_u32 = LoadPrimitive(array_base, low_u32_offset, buffer,
451 VertexFormat::kUint32);
452 switch (offset & 3) {
453 case 0:
454 return ctx.dst->Shl(low_u32, 16u);
455 case 1:
456 return ctx.dst->And(ctx.dst->Shl(low_u32, 8u), 0xffff0000u);
457 case 2:
458 return ctx.dst->And(low_u32, 0xffff0000u);
459 default: { // 3:
460 auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
461 VertexFormat::kUint32);
462 auto* shr = ctx.dst->Shr(low_u32, 8u);
463 auto* shl = ctx.dst->Shl(high_u32, 24u);
464 return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffff0000u);
465 }
466 }
467 };
468
469 // Returns a u16 loaded from offset, packed in the low 16 bits of a u32.
470 // The high 16 bits are 0.
471 auto load_u16_l = [&] {
472 auto low_u32_offset = offset & ~3u;
473 auto* low_u32 = LoadPrimitive(array_base, low_u32_offset, buffer,
474 VertexFormat::kUint32);
475 switch (offset & 3) {
476 case 0:
477 return ctx.dst->And(low_u32, 0xffffu);
478 case 1:
479 return ctx.dst->And(ctx.dst->Shr(low_u32, 8u), 0xffffu);
480 case 2:
481 return ctx.dst->Shr(low_u32, 16u);
482 default: { // 3:
483 auto* high_u32 = LoadPrimitive(array_base, low_u32_offset + 4, buffer,
484 VertexFormat::kUint32);
485 auto* shr = ctx.dst->Shr(low_u32, 24u);
486 auto* shl = ctx.dst->Shl(high_u32, 8u);
487 return ctx.dst->And(ctx.dst->Or(shl, shr), 0xffffu);
488 }
489 }
490 };
491
492 // Returns a i16 loaded from offset, packed in the high 16 bits of a u32.
493 // The low 16 bits are 0.
494 auto load_i16_h = [&] { return ctx.dst->Bitcast<i32>(load_u16_h()); };
495
496 // Assumptions are made that alignment must be at least as large as the size
497 // of a single component.
498 switch (format) {
499 // Basic primitives
500 case VertexFormat::kUint32:
501 case VertexFormat::kSint32:
502 case VertexFormat::kFloat32:
503 return LoadPrimitive(array_base, offset, buffer, format);
504
505 // Vectors of basic primitives
506 case VertexFormat::kUint32x2:
507 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
508 VertexFormat::kUint32, 2);
509 case VertexFormat::kUint32x3:
510 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
511 VertexFormat::kUint32, 3);
512 case VertexFormat::kUint32x4:
513 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.u32(),
514 VertexFormat::kUint32, 4);
515 case VertexFormat::kSint32x2:
516 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
517 VertexFormat::kSint32, 2);
518 case VertexFormat::kSint32x3:
519 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
520 VertexFormat::kSint32, 3);
521 case VertexFormat::kSint32x4:
522 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.i32(),
523 VertexFormat::kSint32, 4);
524 case VertexFormat::kFloat32x2:
525 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
526 VertexFormat::kFloat32, 2);
527 case VertexFormat::kFloat32x3:
528 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
529 VertexFormat::kFloat32, 3);
530 case VertexFormat::kFloat32x4:
531 return LoadVec(array_base, offset, buffer, 4, ctx.dst->ty.f32(),
532 VertexFormat::kFloat32, 4);
533
534 case VertexFormat::kUint8x2: {
535 // yyxx0000, yyxx0000
536 auto* u16s = ctx.dst->vec2<u32>(load_u16_h());
537 // xx000000, yyxx0000
538 auto* shl = ctx.dst->Shl(u16s, ctx.dst->vec2<u32>(8u, 0u));
539 // 000000xx, 000000yy
540 return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24u));
541 }
542 case VertexFormat::kUint8x4: {
543 // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
544 auto* u32s = ctx.dst->vec4<u32>(load_u32());
545 // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
546 auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec4<u32>(24u, 16u, 8u, 0u));
547 // 000000xx, 000000yy, 000000zz, 000000ww
548 return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24u));
549 }
550 case VertexFormat::kUint16x2: {
551 // yyyyxxxx, yyyyxxxx
552 auto* u32s = ctx.dst->vec2<u32>(load_u32());
553 // xxxx0000, yyyyxxxx
554 auto* shl = ctx.dst->Shl(u32s, ctx.dst->vec2<u32>(16u, 0u));
555 // 0000xxxx, 0000yyyy
556 return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16u));
557 }
558 case VertexFormat::kUint16x4: {
559 // yyyyxxxx, wwwwzzzz
560 auto* u32s = ctx.dst->vec2<u32>(load_u32(), load_next_u32());
561 // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
562 auto* xxyy = ctx.dst->MemberAccessor(u32s, "xxyy");
563 // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
564 auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16u, 0u, 16u, 0u));
565 // 0000xxxx, 0000yyyy, 0000zzzz, 0000wwww
566 return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16u));
567 }
568 case VertexFormat::kSint8x2: {
569 // yyxx0000, yyxx0000
570 auto* i16s = ctx.dst->vec2<i32>(load_i16_h());
571 // xx000000, yyxx0000
572 auto* shl = ctx.dst->Shl(i16s, ctx.dst->vec2<u32>(8u, 0u));
573 // ssssssxx, ssssssyy
574 return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(24u));
575 }
576 case VertexFormat::kSint8x4: {
577 // wwzzyyxx, wwzzyyxx, wwzzyyxx, wwzzyyxx
578 auto* i32s = ctx.dst->vec4<i32>(load_i32());
579 // xx000000, yyxx0000, zzyyxx00, wwzzyyxx
580 auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec4<u32>(24u, 16u, 8u, 0u));
581 // ssssssxx, ssssssyy, sssssszz, ssssssww
582 return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(24u));
583 }
584 case VertexFormat::kSint16x2: {
585 // yyyyxxxx, yyyyxxxx
586 auto* i32s = ctx.dst->vec2<i32>(load_i32());
587 // xxxx0000, yyyyxxxx
588 auto* shl = ctx.dst->Shl(i32s, ctx.dst->vec2<u32>(16u, 0u));
589 // ssssxxxx, ssssyyyy
590 return ctx.dst->Shr(shl, ctx.dst->vec2<u32>(16u));
591 }
592 case VertexFormat::kSint16x4: {
593 // yyyyxxxx, wwwwzzzz
594 auto* i32s = ctx.dst->vec2<i32>(load_i32(), load_next_i32());
595 // yyyyxxxx, yyyyxxxx, wwwwzzzz, wwwwzzzz
596 auto* xxyy = ctx.dst->MemberAccessor(i32s, "xxyy");
597 // xxxx0000, yyyyxxxx, zzzz0000, wwwwzzzz
598 auto* shl = ctx.dst->Shl(xxyy, ctx.dst->vec4<u32>(16u, 0u, 16u, 0u));
599 // ssssxxxx, ssssyyyy, sssszzzz, sssswwww
600 return ctx.dst->Shr(shl, ctx.dst->vec4<u32>(16u));
601 }
602 case VertexFormat::kUnorm8x2:
603 return ctx.dst->MemberAccessor(
604 ctx.dst->Call("unpack4x8unorm", load_u16_l()), "xy");
605 case VertexFormat::kSnorm8x2:
606 return ctx.dst->MemberAccessor(
607 ctx.dst->Call("unpack4x8snorm", load_u16_l()), "xy");
608 case VertexFormat::kUnorm8x4:
609 return ctx.dst->Call("unpack4x8unorm", load_u32());
610 case VertexFormat::kSnorm8x4:
611 return ctx.dst->Call("unpack4x8snorm", load_u32());
612 case VertexFormat::kUnorm16x2:
613 return ctx.dst->Call("unpack2x16unorm", load_u32());
614 case VertexFormat::kSnorm16x2:
615 return ctx.dst->Call("unpack2x16snorm", load_u32());
616 case VertexFormat::kFloat16x2:
617 return ctx.dst->Call("unpack2x16float", load_u32());
618 case VertexFormat::kUnorm16x4:
619 return ctx.dst->vec4<f32>(
620 ctx.dst->Call("unpack2x16unorm", load_u32()),
621 ctx.dst->Call("unpack2x16unorm", load_next_u32()));
622 case VertexFormat::kSnorm16x4:
623 return ctx.dst->vec4<f32>(
624 ctx.dst->Call("unpack2x16snorm", load_u32()),
625 ctx.dst->Call("unpack2x16snorm", load_next_u32()));
626 case VertexFormat::kFloat16x4:
627 return ctx.dst->vec4<f32>(
628 ctx.dst->Call("unpack2x16float", load_u32()),
629 ctx.dst->Call("unpack2x16float", load_next_u32()));
630 }
631
632 TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
633 << "format " << static_cast<int>(format);
634 return nullptr;
635 }
636
637 /// Generates an expression reading an aligned basic type (u32, i32, f32) from
638 /// a vertex buffer.
639 /// @param array_base the symbol of the variable holding the base array offset
640 /// of the vertex array (each index is 4-bytes).
641 /// @param offset the byte offset of the data from `buffer_base`
642 /// @param buffer the index of the vertex buffer
643 /// @param format VertexFormat::kUint32, VertexFormat::kSint32 or
644 /// VertexFormat::kFloat32
LoadPrimitivetint::transform::__anon572942530111::State645 const ast::Expression* LoadPrimitive(Symbol array_base,
646 uint32_t offset,
647 uint32_t buffer,
648 VertexFormat format) {
649 const ast::Expression* u32 = nullptr;
650 if ((offset & 3) == 0) {
651 // Aligned load.
652
653 const ast ::Expression* index = nullptr;
654 if (offset > 0) {
655 index = ctx.dst->Add(array_base, offset / 4);
656 } else {
657 index = ctx.dst->Expr(array_base);
658 }
659 u32 = ctx.dst->IndexAccessor(
660 ctx.dst->MemberAccessor(GetVertexBufferName(buffer),
661 GetStructBufferName()),
662 index);
663
664 } else {
665 // Unaligned load
666 uint32_t offset_aligned = offset & ~3u;
667 auto* low = LoadPrimitive(array_base, offset_aligned, buffer,
668 VertexFormat::kUint32);
669 auto* high = LoadPrimitive(array_base, offset_aligned + 4u, buffer,
670 VertexFormat::kUint32);
671
672 uint32_t shift = 8u * (offset & 3u);
673
674 auto* low_shr = ctx.dst->Shr(low, shift);
675 auto* high_shl = ctx.dst->Shl(high, 32u - shift);
676 u32 = ctx.dst->Or(low_shr, high_shl);
677 }
678
679 switch (format) {
680 case VertexFormat::kUint32:
681 return u32;
682 case VertexFormat::kSint32:
683 return ctx.dst->Bitcast(ctx.dst->ty.i32(), u32);
684 case VertexFormat::kFloat32:
685 return ctx.dst->Bitcast(ctx.dst->ty.f32(), u32);
686 default:
687 break;
688 }
689 TINT_UNREACHABLE(Transform, ctx.dst->Diagnostics())
690 << "invalid format for LoadPrimitive" << static_cast<int>(format);
691 return nullptr;
692 }
693
694 /// Generates an expression reading a vec2/3/4 from a vertex buffer.
695 /// @param array_base the symbol of the variable holding the base array offset
696 /// of the vertex array (each index is 4-bytes).
697 /// @param offset the byte offset of the data from `buffer_base`
698 /// @param buffer the index of the vertex buffer
699 /// @param element_stride stride between elements, in bytes
700 /// @param base_type underlying AST type
701 /// @param base_format underlying vertex format
702 /// @param count how many elements the vector has
LoadVectint::transform::__anon572942530111::State703 const ast::Expression* LoadVec(Symbol array_base,
704 uint32_t offset,
705 uint32_t buffer,
706 uint32_t element_stride,
707 const ast::Type* base_type,
708 VertexFormat base_format,
709 uint32_t count) {
710 ast::ExpressionList expr_list;
711 for (uint32_t i = 0; i < count; ++i) {
712 // Offset read position by element_stride for each component
713 uint32_t primitive_offset = offset + element_stride * i;
714 expr_list.push_back(
715 LoadPrimitive(array_base, primitive_offset, buffer, base_format));
716 }
717
718 return ctx.dst->Construct(ctx.dst->create<ast::Vector>(base_type, count),
719 std::move(expr_list));
720 }
721
722 /// Process a non-struct entry point parameter.
723 /// Generate function-scope variables for location parameters, and record
724 /// vertex_index and instance_index builtins if present.
725 /// @param func the entry point function
726 /// @param param the parameter to process
ProcessNonStructParametertint::transform::__anon572942530111::State727 void ProcessNonStructParameter(const ast::Function* func,
728 const ast::Variable* param) {
729 if (auto* location =
730 ast::GetDecoration<ast::LocationDecoration>(param->decorations)) {
731 // Create a function-scope variable to replace the parameter.
732 auto func_var_sym = ctx.Clone(param->symbol);
733 auto* func_var_type = ctx.Clone(param->type);
734 auto* func_var = ctx.dst->Var(func_var_sym, func_var_type);
735 ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
736 // Capture mapping from location to the new variable.
737 LocationInfo info;
738 info.expr = [this, func_var]() { return ctx.dst->Expr(func_var); };
739 info.type = ctx.src->Sem().Get(param)->Type();
740 location_info[location->value] = info;
741 } else if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
742 param->decorations)) {
743 // Check for existing vertex_index and instance_index builtins.
744 if (builtin->builtin == ast::Builtin::kVertexIndex) {
745 vertex_index_expr = [this, param]() {
746 return ctx.dst->Expr(ctx.Clone(param->symbol));
747 };
748 } else if (builtin->builtin == ast::Builtin::kInstanceIndex) {
749 instance_index_expr = [this, param]() {
750 return ctx.dst->Expr(ctx.Clone(param->symbol));
751 };
752 }
753 new_function_parameters.push_back(ctx.Clone(param));
754 } else {
755 TINT_ICE(Transform, ctx.dst->Diagnostics())
756 << "Invalid entry point parameter";
757 }
758 }
759
760 /// Process a struct entry point parameter.
761 /// If the struct has members with location attributes, push the parameter to
762 /// a function-scope variable and create a new struct parameter without those
763 /// attributes. Record expressions for members that are vertex_index and
764 /// instance_index builtins.
765 /// @param func the entry point function
766 /// @param param the parameter to process
767 /// @param struct_ty the structure type
ProcessStructParametertint::transform::__anon572942530111::State768 void ProcessStructParameter(const ast::Function* func,
769 const ast::Variable* param,
770 const ast::Struct* struct_ty) {
771 auto param_sym = ctx.Clone(param->symbol);
772
773 // Process the struct members.
774 bool has_locations = false;
775 ast::StructMemberList members_to_clone;
776 for (auto* member : struct_ty->members) {
777 auto member_sym = ctx.Clone(member->symbol);
778 std::function<const ast::Expression*()> member_expr = [this, param_sym,
779 member_sym]() {
780 return ctx.dst->MemberAccessor(param_sym, member_sym);
781 };
782
783 if (auto* location = ast::GetDecoration<ast::LocationDecoration>(
784 member->decorations)) {
785 // Capture mapping from location to struct member.
786 LocationInfo info;
787 info.expr = member_expr;
788 info.type = ctx.src->Sem().Get(member)->Type();
789 location_info[location->value] = info;
790 has_locations = true;
791 } else if (auto* builtin = ast::GetDecoration<ast::BuiltinDecoration>(
792 member->decorations)) {
793 // Check for existing vertex_index and instance_index builtins.
794 if (builtin->builtin == ast::Builtin::kVertexIndex) {
795 vertex_index_expr = member_expr;
796 } else if (builtin->builtin == ast::Builtin::kInstanceIndex) {
797 instance_index_expr = member_expr;
798 }
799 members_to_clone.push_back(member);
800 } else {
801 TINT_ICE(Transform, ctx.dst->Diagnostics())
802 << "Invalid entry point parameter";
803 }
804 }
805
806 if (!has_locations) {
807 // Nothing to do.
808 new_function_parameters.push_back(ctx.Clone(param));
809 return;
810 }
811
812 // Create a function-scope variable to replace the parameter.
813 auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->type));
814 ctx.InsertFront(func->body->statements, ctx.dst->Decl(func_var));
815
816 if (!members_to_clone.empty()) {
817 // Create a new struct without the location attributes.
818 ast::StructMemberList new_members;
819 for (auto* member : members_to_clone) {
820 auto member_sym = ctx.Clone(member->symbol);
821 auto* member_type = ctx.Clone(member->type);
822 auto member_decos = ctx.Clone(member->decorations);
823 new_members.push_back(
824 ctx.dst->Member(member_sym, member_type, std::move(member_decos)));
825 }
826 auto* new_struct = ctx.dst->Structure(ctx.dst->Sym(), new_members);
827
828 // Create a new function parameter with this struct.
829 auto* new_param =
830 ctx.dst->Param(ctx.dst->Sym(), ctx.dst->ty.Of(new_struct));
831 new_function_parameters.push_back(new_param);
832
833 // Copy values from the new parameter to the function-scope variable.
834 for (auto* member : members_to_clone) {
835 auto member_name = ctx.Clone(member->symbol);
836 ctx.InsertFront(
837 func->body->statements,
838 ctx.dst->Assign(ctx.dst->MemberAccessor(func_var, member_name),
839 ctx.dst->MemberAccessor(new_param, member_name)));
840 }
841 }
842 }
843
844 /// Process an entry point function.
845 /// @param func the entry point function
Processtint::transform::__anon572942530111::State846 void Process(const ast::Function* func) {
847 if (func->body->Empty()) {
848 return;
849 }
850
851 // Process entry point parameters.
852 for (auto* param : func->params) {
853 auto* sem = ctx.src->Sem().Get(param);
854 if (auto* str = sem->Type()->As<sem::Struct>()) {
855 ProcessStructParameter(func, param, str->Declaration());
856 } else {
857 ProcessNonStructParameter(func, param);
858 }
859 }
860
861 // Insert new parameters for vertex_index and instance_index if needed.
862 if (!vertex_index_expr) {
863 for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
864 if (layout.step_mode == VertexStepMode::kVertex) {
865 auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index");
866 new_function_parameters.push_back(
867 ctx.dst->Param(name, ctx.dst->ty.u32(),
868 {ctx.dst->Builtin(ast::Builtin::kVertexIndex)}));
869 vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); };
870 break;
871 }
872 }
873 }
874 if (!instance_index_expr) {
875 for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) {
876 if (layout.step_mode == VertexStepMode::kInstance) {
877 auto name = ctx.dst->Symbols().New("tint_pulling_instance_index");
878 new_function_parameters.push_back(
879 ctx.dst->Param(name, ctx.dst->ty.u32(),
880 {ctx.dst->Builtin(ast::Builtin::kInstanceIndex)}));
881 instance_index_expr = [this, name]() { return ctx.dst->Expr(name); };
882 break;
883 }
884 }
885 }
886
887 // Generate vertex pulling preamble.
888 if (auto* block = CreateVertexPullingPreamble()) {
889 ctx.InsertFront(func->body->statements, block);
890 }
891
892 // Rewrite the function header with the new parameters.
893 auto func_sym = ctx.Clone(func->symbol);
894 auto* ret_type = ctx.Clone(func->return_type);
895 auto* body = ctx.Clone(func->body);
896 auto decos = ctx.Clone(func->decorations);
897 auto ret_decos = ctx.Clone(func->return_type_decorations);
898 auto* new_func = ctx.dst->create<ast::Function>(
899 func->source, func_sym, new_function_parameters, ret_type, body,
900 std::move(decos), std::move(ret_decos));
901 ctx.Replace(func, new_func);
902 }
903 };
904
905 } // namespace
906
907 VertexPulling::VertexPulling() = default;
908 VertexPulling::~VertexPulling() = default;
909
Run(CloneContext & ctx,const DataMap & inputs,DataMap &)910 void VertexPulling::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) {
911 auto cfg = cfg_;
912 if (auto* cfg_data = inputs.Get<Config>()) {
913 cfg = *cfg_data;
914 }
915
916 // Find entry point
917 auto* func = ctx.src->AST().Functions().Find(
918 ctx.src->Symbols().Get(cfg.entry_point_name),
919 ast::PipelineStage::kVertex);
920 if (func == nullptr) {
921 ctx.dst->Diagnostics().add_error(diag::System::Transform,
922 "Vertex stage entry point not found");
923 return;
924 }
925
926 // TODO(idanr): Need to check shader locations in descriptor cover all
927 // attributes
928
929 // TODO(idanr): Make sure we covered all error cases, to guarantee the
930 // following stages will pass
931
932 State state{ctx, cfg};
933 state.AddVertexStorageBuffers();
934 state.Process(func);
935
936 ctx.Clone();
937 }
938
939 VertexPulling::Config::Config() = default;
940 VertexPulling::Config::Config(const Config&) = default;
941 VertexPulling::Config::~Config() = default;
942 VertexPulling::Config& VertexPulling::Config::operator=(const Config&) =
943 default;
944
945 VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor() = default;
946
VertexBufferLayoutDescriptor(uint32_t in_array_stride,VertexStepMode in_step_mode,std::vector<VertexAttributeDescriptor> in_attributes)947 VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
948 uint32_t in_array_stride,
949 VertexStepMode in_step_mode,
950 std::vector<VertexAttributeDescriptor> in_attributes)
951 : array_stride(in_array_stride),
952 step_mode(in_step_mode),
953 attributes(std::move(in_attributes)) {}
954
955 VertexBufferLayoutDescriptor::VertexBufferLayoutDescriptor(
956 const VertexBufferLayoutDescriptor& other) = default;
957
958 VertexBufferLayoutDescriptor& VertexBufferLayoutDescriptor::operator=(
959 const VertexBufferLayoutDescriptor& other) = default;
960
961 VertexBufferLayoutDescriptor::~VertexBufferLayoutDescriptor() = default;
962
963 } // namespace transform
964 } // namespace tint
965