1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/ir/SkSLSwizzle.h"
9
10 #include "include/private/SkTOptional.h"
11 #include "include/sksl/SkSLErrorReporter.h"
12 #include "src/sksl/SkSLAnalysis.h"
13 #include "src/sksl/SkSLConstantFolder.h"
14 #include "src/sksl/SkSLProgramSettings.h"
15 #include "src/sksl/ir/SkSLConstructor.h"
16 #include "src/sksl/ir/SkSLConstructorScalarCast.h"
17 #include "src/sksl/ir/SkSLConstructorSplat.h"
18 #include "src/sksl/ir/SkSLLiteral.h"
19
20 namespace SkSL {
21
validate_swizzle_domain(const ComponentArray & fields)22 static bool validate_swizzle_domain(const ComponentArray& fields) {
23 enum SwizzleDomain {
24 kCoordinate,
25 kColor,
26 kUV,
27 kRectangle,
28 };
29
30 skstd::optional<SwizzleDomain> domain;
31
32 for (int8_t field : fields) {
33 SwizzleDomain fieldDomain;
34 switch (field) {
35 case SwizzleComponent::X:
36 case SwizzleComponent::Y:
37 case SwizzleComponent::Z:
38 case SwizzleComponent::W:
39 fieldDomain = kCoordinate;
40 break;
41 case SwizzleComponent::R:
42 case SwizzleComponent::G:
43 case SwizzleComponent::B:
44 case SwizzleComponent::A:
45 fieldDomain = kColor;
46 break;
47 case SwizzleComponent::S:
48 case SwizzleComponent::T:
49 case SwizzleComponent::P:
50 case SwizzleComponent::Q:
51 fieldDomain = kUV;
52 break;
53 case SwizzleComponent::UL:
54 case SwizzleComponent::UT:
55 case SwizzleComponent::UR:
56 case SwizzleComponent::UB:
57 fieldDomain = kRectangle;
58 break;
59 case SwizzleComponent::ZERO:
60 case SwizzleComponent::ONE:
61 continue;
62 default:
63 return false;
64 }
65
66 if (!domain.has_value()) {
67 domain = fieldDomain;
68 } else if (domain != fieldDomain) {
69 return false;
70 }
71 }
72
73 return true;
74 }
75
mask_char(int8_t component)76 static char mask_char(int8_t component) {
77 switch (component) {
78 case SwizzleComponent::X: return 'x';
79 case SwizzleComponent::Y: return 'y';
80 case SwizzleComponent::Z: return 'z';
81 case SwizzleComponent::W: return 'w';
82 case SwizzleComponent::R: return 'r';
83 case SwizzleComponent::G: return 'g';
84 case SwizzleComponent::B: return 'b';
85 case SwizzleComponent::A: return 'a';
86 case SwizzleComponent::S: return 's';
87 case SwizzleComponent::T: return 't';
88 case SwizzleComponent::P: return 'p';
89 case SwizzleComponent::Q: return 'q';
90 case SwizzleComponent::UL: return 'L';
91 case SwizzleComponent::UT: return 'T';
92 case SwizzleComponent::UR: return 'R';
93 case SwizzleComponent::UB: return 'B';
94 case SwizzleComponent::ZERO: return '0';
95 case SwizzleComponent::ONE: return '1';
96 default: SkUNREACHABLE;
97 }
98 }
99
mask_string(const ComponentArray & components)100 static String mask_string(const ComponentArray& components) {
101 String result;
102 for (int8_t component : components) {
103 result += mask_char(component);
104 }
105 return result;
106 }
107
optimize_constructor_swizzle(const Context & context,const AnyConstructor & base,ComponentArray components)108 static std::unique_ptr<Expression> optimize_constructor_swizzle(const Context& context,
109 const AnyConstructor& base,
110 ComponentArray components) {
111 auto baseArguments = base.argumentSpan();
112 std::unique_ptr<Expression> replacement;
113 const Type& exprType = base.type();
114 const Type& componentType = exprType.componentType();
115 int swizzleSize = components.size();
116
117 // Swizzles can duplicate some elements and discard others, e.g.
118 // `half4(1, 2, 3, 4).xxz` --> `half3(1, 1, 3)`. However, there are constraints:
119 // - Expressions with side effects need to occur exactly once, even if they would otherwise be
120 // swizzle-eliminated
121 // - Non-trivial expressions should not be repeated, but elimination is OK.
122 //
123 // Look up the argument for the constructor at each index. This is typically simple but for
124 // weird cases like `half4(bar.yz, half2(foo))`, it can be harder than it seems. This example
125 // would result in:
126 // argMap[0] = {.fArgIndex = 0, .fComponent = 0} (bar.yz .x)
127 // argMap[1] = {.fArgIndex = 0, .fComponent = 1} (bar.yz .y)
128 // argMap[2] = {.fArgIndex = 1, .fComponent = 0} (half2(foo) .x)
129 // argMap[3] = {.fArgIndex = 1, .fComponent = 1} (half2(foo) .y)
130 struct ConstructorArgMap {
131 int8_t fArgIndex;
132 int8_t fComponent;
133 };
134
135 int numConstructorArgs = base.type().columns();
136 ConstructorArgMap argMap[4] = {};
137 int writeIdx = 0;
138 for (int argIdx = 0; argIdx < (int)baseArguments.size(); ++argIdx) {
139 const Expression& arg = *baseArguments[argIdx];
140 const Type& argType = arg.type();
141
142 if (!argType.isScalar() && !argType.isVector()) {
143 return nullptr;
144 }
145
146 int argSlots = argType.slotCount();
147 for (int componentIdx = 0; componentIdx < argSlots; ++componentIdx) {
148 argMap[writeIdx].fArgIndex = argIdx;
149 argMap[writeIdx].fComponent = componentIdx;
150 ++writeIdx;
151 }
152 }
153 SkASSERT(writeIdx == numConstructorArgs);
154
155 // Count up the number of times each constructor argument is used by the swizzle.
156 // `half4(bar.yz, half2(foo)).xwxy` -> { 3, 1 }
157 // - bar.yz is referenced 3 times, by `.x_xy`
158 // - half(foo) is referenced 1 time, by `._w__`
159 int8_t exprUsed[4] = {};
160 for (int8_t c : components) {
161 exprUsed[argMap[c].fArgIndex]++;
162 }
163
164 for (int index = 0; index < numConstructorArgs; ++index) {
165 int8_t constructorArgIndex = argMap[index].fArgIndex;
166 const Expression& baseArg = *baseArguments[constructorArgIndex];
167
168 // Check that non-trivial expressions are not swizzled in more than once.
169 if (exprUsed[constructorArgIndex] > 1 && !Analysis::IsTrivialExpression(baseArg)) {
170 return nullptr;
171 }
172 // Check that side-effect-bearing expressions are swizzled in exactly once.
173 if (exprUsed[constructorArgIndex] != 1 && baseArg.hasSideEffects()) {
174 return nullptr;
175 }
176 }
177
178 struct ReorderedArgument {
179 int8_t fArgIndex;
180 ComponentArray fComponents;
181 };
182 SkSTArray<4, ReorderedArgument> reorderedArgs;
183 for (int8_t c : components) {
184 const ConstructorArgMap& argument = argMap[c];
185 const Expression& baseArg = *baseArguments[argument.fArgIndex];
186
187 if (baseArg.type().isScalar()) {
188 // This argument is a scalar; add it to the list as-is.
189 SkASSERT(argument.fComponent == 0);
190 reorderedArgs.push_back({argument.fArgIndex,
191 ComponentArray{}});
192 } else {
193 // This argument is a component from a vector.
194 SkASSERT(baseArg.type().isVector());
195 SkASSERT(argument.fComponent < baseArg.type().columns());
196 if (reorderedArgs.empty() ||
197 reorderedArgs.back().fArgIndex != argument.fArgIndex) {
198 // This can't be combined with the previous argument. Add a new one.
199 reorderedArgs.push_back({argument.fArgIndex,
200 ComponentArray{argument.fComponent}});
201 } else {
202 // Since we know this argument uses components, it should already have at least one
203 // component set.
204 SkASSERT(!reorderedArgs.back().fComponents.empty());
205 // Build up the current argument with one more component.
206 reorderedArgs.back().fComponents.push_back(argument.fComponent);
207 }
208 }
209 }
210
211 // Convert our reordered argument list to an actual array of expressions, with the new order and
212 // any new inner swizzles that need to be applied.
213 ExpressionArray newArgs;
214 newArgs.reserve_back(swizzleSize);
215 for (const ReorderedArgument& reorderedArg : reorderedArgs) {
216 std::unique_ptr<Expression> newArg =
217 baseArguments[reorderedArg.fArgIndex]->clone();
218
219 if (reorderedArg.fComponents.empty()) {
220 newArgs.push_back(std::move(newArg));
221 } else {
222 newArgs.push_back(Swizzle::Make(context, std::move(newArg),
223 reorderedArg.fComponents));
224 }
225 }
226
227 // Wrap the new argument list in a constructor.
228 auto ctor = Constructor::Convert(context,
229 base.fLine,
230 componentType.toCompound(context, swizzleSize, /*rows=*/1),
231 std::move(newArgs));
232 SkASSERT(ctor);
233 return ctor;
234 }
235
Convert(const Context & context,std::unique_ptr<Expression> base,skstd::string_view maskString)236 std::unique_ptr<Expression> Swizzle::Convert(const Context& context,
237 std::unique_ptr<Expression> base,
238 skstd::string_view maskString) {
239 ComponentArray components;
240 for (char field : maskString) {
241 switch (field) {
242 case '0': components.push_back(SwizzleComponent::ZERO); break;
243 case '1': components.push_back(SwizzleComponent::ONE); break;
244 case 'x': components.push_back(SwizzleComponent::X); break;
245 case 'r': components.push_back(SwizzleComponent::R); break;
246 case 's': components.push_back(SwizzleComponent::S); break;
247 case 'L': components.push_back(SwizzleComponent::UL); break;
248 case 'y': components.push_back(SwizzleComponent::Y); break;
249 case 'g': components.push_back(SwizzleComponent::G); break;
250 case 't': components.push_back(SwizzleComponent::T); break;
251 case 'T': components.push_back(SwizzleComponent::UT); break;
252 case 'z': components.push_back(SwizzleComponent::Z); break;
253 case 'b': components.push_back(SwizzleComponent::B); break;
254 case 'p': components.push_back(SwizzleComponent::P); break;
255 case 'R': components.push_back(SwizzleComponent::UR); break;
256 case 'w': components.push_back(SwizzleComponent::W); break;
257 case 'a': components.push_back(SwizzleComponent::A); break;
258 case 'q': components.push_back(SwizzleComponent::Q); break;
259 case 'B': components.push_back(SwizzleComponent::UB); break;
260 default:
261 context.fErrors->error(base->fLine,
262 String::printf("invalid swizzle component '%c'", field));
263 return nullptr;
264 }
265 }
266 return Convert(context, std::move(base), std::move(components));
267 }
268
269 // Swizzles are complicated due to constant components. The most difficult case is a mask like
270 // '.x1w0'. A naive approach might turn that into 'float4(base.x, 1, base.w, 0)', but that evaluates
271 // 'base' twice. We instead group the swizzle mask ('xw') and constants ('1, 0') together and use a
272 // secondary swizzle to put them back into the right order, so in this case we end up with
273 // 'float4(base.xw, 1, 0).xzyw'.
Convert(const Context & context,std::unique_ptr<Expression> base,ComponentArray inComponents)274 std::unique_ptr<Expression> Swizzle::Convert(const Context& context,
275 std::unique_ptr<Expression> base,
276 ComponentArray inComponents) {
277 if (!validate_swizzle_domain(inComponents)) {
278 context.fErrors->error(base->fLine,
279 "invalid swizzle mask '" + mask_string(inComponents) + "'");
280 return nullptr;
281 }
282
283 const int line = base->fLine;
284 const Type& baseType = base->type();
285
286 if (!baseType.isVector() && !baseType.isScalar()) {
287 context.fErrors->error(
288 line, "cannot swizzle value of type '" + baseType.displayName() + "'");
289 return nullptr;
290 }
291
292 if (inComponents.count() > 4) {
293 context.fErrors->error(line,
294 "too many components in swizzle mask '" + mask_string(inComponents) + "'");
295 return nullptr;
296 }
297
298 ComponentArray maskComponents;
299 bool foundXYZW = false;
300 for (int i = 0; i < inComponents.count(); ++i) {
301 switch (inComponents[i]) {
302 case SwizzleComponent::ZERO:
303 case SwizzleComponent::ONE:
304 // Skip over constant fields for now.
305 break;
306 case SwizzleComponent::X:
307 case SwizzleComponent::R:
308 case SwizzleComponent::S:
309 case SwizzleComponent::UL:
310 foundXYZW = true;
311 maskComponents.push_back(SwizzleComponent::X);
312 break;
313 case SwizzleComponent::Y:
314 case SwizzleComponent::G:
315 case SwizzleComponent::T:
316 case SwizzleComponent::UT:
317 foundXYZW = true;
318 if (baseType.columns() >= 2) {
319 maskComponents.push_back(SwizzleComponent::Y);
320 break;
321 }
322 [[fallthrough]];
323 case SwizzleComponent::Z:
324 case SwizzleComponent::B:
325 case SwizzleComponent::P:
326 case SwizzleComponent::UR:
327 foundXYZW = true;
328 if (baseType.columns() >= 3) {
329 maskComponents.push_back(SwizzleComponent::Z);
330 break;
331 }
332 [[fallthrough]];
333 case SwizzleComponent::W:
334 case SwizzleComponent::A:
335 case SwizzleComponent::Q:
336 case SwizzleComponent::UB:
337 foundXYZW = true;
338 if (baseType.columns() >= 4) {
339 maskComponents.push_back(SwizzleComponent::W);
340 break;
341 }
342 [[fallthrough]];
343 default:
344 // The swizzle component references a field that doesn't exist in the base type.
345 context.fErrors->error(line,
346 String::printf("invalid swizzle component '%c'",
347 mask_char(inComponents[i])));
348 return nullptr;
349 }
350 }
351
352 if (!foundXYZW) {
353 context.fErrors->error(line, "swizzle must refer to base expression");
354 return nullptr;
355 }
356
357 // Coerce literals in expressions such as `(12345).xxx` to their actual type.
358 base = baseType.scalarTypeForLiteral().coerceExpression(std::move(base), context);
359 if (!base) {
360 return nullptr;
361 }
362
363 // First, we need a vector expression that is the non-constant portion of the swizzle, packed:
364 // scalar.xxx -> type3(scalar)
365 // scalar.x0x0 -> type2(scalar)
366 // vector.zyx -> vector.zyx
367 // vector.x0y0 -> vector.xy
368 std::unique_ptr<Expression> expr = Swizzle::Make(context, std::move(base), maskComponents);
369
370 // If we have processed the entire swizzle, we're done.
371 if (maskComponents.count() == inComponents.count()) {
372 return expr;
373 }
374
375 // Now we create a constructor that has the correct number of elements for the final swizzle,
376 // with all fields at the start. It's not finished yet; constants we need will be added below.
377 // scalar.x0x0 -> type4(type2(x), ...)
378 // vector.y111 -> type4(vector.y, ...)
379 // vector.z10x -> type4(vector.zx, ...)
380 //
381 // The constructor will have at most three arguments: { base expr, constant 0, constant 1 }
382 ExpressionArray constructorArgs;
383 constructorArgs.reserve_back(3);
384 constructorArgs.push_back(std::move(expr));
385
386 // Apply another swizzle to shuffle the constants into the correct place. Any constant values we
387 // need are also tacked on to the end of the constructor.
388 // scalar.x0x0 -> type4(type2(x), 0).xyxy
389 // vector.y111 -> type4(vector.y, 1).xyyy
390 // vector.z10x -> type4(vector.zx, 1, 0).xzwy
391 const Type* scalarType = &baseType.componentType();
392 ComponentArray swizzleComponents;
393 int maskFieldIdx = 0;
394 int constantFieldIdx = maskComponents.size();
395 int constantZeroIdx = -1, constantOneIdx = -1;
396
397 for (int i = 0; i < inComponents.count(); i++) {
398 switch (inComponents[i]) {
399 case SwizzleComponent::ZERO:
400 if (constantZeroIdx == -1) {
401 // Synthesize a 'type(0)' argument at the end of the constructor.
402 constructorArgs.push_back(ConstructorScalarCast::Make(
403 context, line, *scalarType,
404 Literal::MakeInt(context, line, /*value=*/0)));
405 constantZeroIdx = constantFieldIdx++;
406 }
407 swizzleComponents.push_back(constantZeroIdx);
408 break;
409 case SwizzleComponent::ONE:
410 if (constantOneIdx == -1) {
411 // Synthesize a 'type(1)' argument at the end of the constructor.
412 constructorArgs.push_back(ConstructorScalarCast::Make(
413 context, line, *scalarType,
414 Literal::MakeInt(context, line, /*value=*/1)));
415 constantOneIdx = constantFieldIdx++;
416 }
417 swizzleComponents.push_back(constantOneIdx);
418 break;
419 default:
420 // The non-constant fields are already in the expected order.
421 swizzleComponents.push_back(maskFieldIdx++);
422 break;
423 }
424 }
425
426 expr = Constructor::Convert(context, line,
427 scalarType->toCompound(context, constantFieldIdx, /*rows=*/1),
428 std::move(constructorArgs));
429 if (!expr) {
430 return nullptr;
431 }
432
433 return Swizzle::Make(context, std::move(expr), swizzleComponents);
434 }
435
Make(const Context & context,std::unique_ptr<Expression> expr,ComponentArray components)436 std::unique_ptr<Expression> Swizzle::Make(const Context& context,
437 std::unique_ptr<Expression> expr,
438 ComponentArray components) {
439 const Type& exprType = expr->type();
440 SkASSERTF(exprType.isVector() || exprType.isScalar(),
441 "cannot swizzle type '%s'", exprType.description().c_str());
442 SkASSERT(components.count() >= 1 && components.count() <= 4);
443
444 // Confirm that the component array only contains X/Y/Z/W. (Call MakeWith01 if you want support
445 // for ZERO and ONE. Once initial IR generation is complete, no swizzles should have zeros or
446 // ones in them.)
447 SkASSERT(std::all_of(components.begin(), components.end(), [](int8_t component) {
448 return component >= SwizzleComponent::X &&
449 component <= SwizzleComponent::W;
450 }));
451
452 // SkSL supports splatting a scalar via `scalar.xxxx`, but not all versions of GLSL allow this.
453 // Replace swizzles with equivalent splat constructors (`scalar.xxx` --> `half3(value)`).
454 if (exprType.isScalar()) {
455 int line = expr->fLine;
456 return ConstructorSplat::Make(context, line,
457 exprType.toCompound(context, components.size(), /*rows=*/1),
458 std::move(expr));
459 }
460
461 // Detect identity swizzles like `color.rgba` and return the base-expression as-is.
462 if (components.count() == exprType.columns()) {
463 bool identity = true;
464 for (int i = 0; i < components.count(); ++i) {
465 if (components[i] != i) {
466 identity = false;
467 break;
468 }
469 }
470 if (identity) {
471 return expr;
472 }
473 }
474
475 // Optimize swizzles of swizzles, e.g. replace `foo.argb.rggg` with `foo.arrr`.
476 if (expr->is<Swizzle>()) {
477 Swizzle& base = expr->as<Swizzle>();
478 ComponentArray combined;
479 for (int8_t c : components) {
480 combined.push_back(base.components()[c]);
481 }
482
483 // It may actually be possible to further simplify this swizzle. Go again.
484 // (e.g. `color.abgr.abgr` --> `color.rgba` --> `color`.)
485 return Swizzle::Make(context, std::move(base.base()), combined);
486 }
487
488 // If we are swizzling a constant expression, we can use its value instead here (so that
489 // swizzles like `colorWhite.x` can be simplified to `1`).
490 const Expression* value = ConstantFolder::GetConstantValueForVariable(*expr);
491
492 // `half4(scalar).zyy` can be optimized to `half3(scalar)`, and `half3(scalar).y` can be
493 // optimized to just `scalar`. The swizzle components don't actually matter, as every field
494 // in a splat constructor holds the same value.
495 if (value->is<ConstructorSplat>()) {
496 const ConstructorSplat& splat = value->as<ConstructorSplat>();
497 return ConstructorSplat::Make(
498 context, splat.fLine,
499 splat.type().componentType().toCompound(context, components.size(), /*rows=*/1),
500 splat.argument()->clone());
501 }
502
503 // Optimize swizzles of constructors.
504 if (value->isAnyConstructor()) {
505 const AnyConstructor& ctor = value->asAnyConstructor();
506 if (auto replacement = optimize_constructor_swizzle(context, ctor, components)) {
507 return replacement;
508 }
509 }
510
511 // The swizzle could not be simplified, so apply the requested swizzle to the base expression.
512 return std::make_unique<Swizzle>(context, std::move(expr), components);
513 }
514
515 } // namespace SkSL
516