1 /*
2 * Copyright 2021 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #include "src/sksl/ir/SkSLSwizzle.h"
9
10 #include "include/core/SkSpan.h"
11 #include "include/private/SkSLString.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/sksl/SkSLErrorReporter.h"
14 #include "include/sksl/SkSLOperator.h"
15 #include "src/sksl/SkSLAnalysis.h"
16 #include "src/sksl/SkSLConstantFolder.h"
17 #include "src/sksl/SkSLContext.h"
18 #include "src/sksl/ir/SkSLConstructorCompound.h"
19 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
20 #include "src/sksl/ir/SkSLConstructorScalarCast.h"
21 #include "src/sksl/ir/SkSLConstructorSplat.h"
22 #include "src/sksl/ir/SkSLLiteral.h"
23
24 #include <algorithm>
25 #include <cstddef>
26 #include <cstdint>
27 #include <optional>
28
29 namespace SkSL {
30
validate_swizzle_domain(const ComponentArray & fields)31 static bool validate_swizzle_domain(const ComponentArray& fields) {
32 enum SwizzleDomain {
33 kCoordinate,
34 kColor,
35 kUV,
36 kRectangle,
37 };
38
39 std::optional<SwizzleDomain> domain;
40
41 for (int8_t field : fields) {
42 SwizzleDomain fieldDomain;
43 switch (field) {
44 case SwizzleComponent::X:
45 case SwizzleComponent::Y:
46 case SwizzleComponent::Z:
47 case SwizzleComponent::W:
48 fieldDomain = kCoordinate;
49 break;
50 case SwizzleComponent::R:
51 case SwizzleComponent::G:
52 case SwizzleComponent::B:
53 case SwizzleComponent::A:
54 fieldDomain = kColor;
55 break;
56 case SwizzleComponent::S:
57 case SwizzleComponent::T:
58 case SwizzleComponent::P:
59 case SwizzleComponent::Q:
60 fieldDomain = kUV;
61 break;
62 case SwizzleComponent::UL:
63 case SwizzleComponent::UT:
64 case SwizzleComponent::UR:
65 case SwizzleComponent::UB:
66 fieldDomain = kRectangle;
67 break;
68 case SwizzleComponent::ZERO:
69 case SwizzleComponent::ONE:
70 continue;
71 default:
72 return false;
73 }
74
75 if (!domain.has_value()) {
76 domain = fieldDomain;
77 } else if (domain != fieldDomain) {
78 return false;
79 }
80 }
81
82 return true;
83 }
84
mask_char(int8_t component)85 static char mask_char(int8_t component) {
86 switch (component) {
87 case SwizzleComponent::X: return 'x';
88 case SwizzleComponent::Y: return 'y';
89 case SwizzleComponent::Z: return 'z';
90 case SwizzleComponent::W: return 'w';
91 case SwizzleComponent::R: return 'r';
92 case SwizzleComponent::G: return 'g';
93 case SwizzleComponent::B: return 'b';
94 case SwizzleComponent::A: return 'a';
95 case SwizzleComponent::S: return 's';
96 case SwizzleComponent::T: return 't';
97 case SwizzleComponent::P: return 'p';
98 case SwizzleComponent::Q: return 'q';
99 case SwizzleComponent::UL: return 'L';
100 case SwizzleComponent::UT: return 'T';
101 case SwizzleComponent::UR: return 'R';
102 case SwizzleComponent::UB: return 'B';
103 case SwizzleComponent::ZERO: return '0';
104 case SwizzleComponent::ONE: return '1';
105 default: SkUNREACHABLE;
106 }
107 }
108
mask_string(const ComponentArray & components)109 static std::string mask_string(const ComponentArray& components) {
110 std::string result;
111 for (int8_t component : components) {
112 result += mask_char(component);
113 }
114 return result;
115 }
116
optimize_constructor_swizzle(const Context & context,Position pos,const ConstructorCompound & base,ComponentArray components)117 static std::unique_ptr<Expression> optimize_constructor_swizzle(const Context& context,
118 Position pos,
119 const ConstructorCompound& base,
120 ComponentArray components) {
121 auto baseArguments = base.argumentSpan();
122 std::unique_ptr<Expression> replacement;
123 const Type& exprType = base.type();
124 const Type& componentType = exprType.componentType();
125 int swizzleSize = components.size();
126
127 // Swizzles can duplicate some elements and discard others, e.g.
128 // `half4(1, 2, 3, 4).xxz` --> `half3(1, 1, 3)`. However, there are constraints:
129 // - Expressions with side effects need to occur exactly once, even if they would otherwise be
130 // swizzle-eliminated
131 // - Non-trivial expressions should not be repeated, but elimination is OK.
132 //
133 // Look up the argument for the constructor at each index. This is typically simple but for
134 // weird cases like `half4(bar.yz, half2(foo))`, it can be harder than it seems. This example
135 // would result in:
136 // argMap[0] = {.fArgIndex = 0, .fComponent = 0} (bar.yz .x)
137 // argMap[1] = {.fArgIndex = 0, .fComponent = 1} (bar.yz .y)
138 // argMap[2] = {.fArgIndex = 1, .fComponent = 0} (half2(foo) .x)
139 // argMap[3] = {.fArgIndex = 1, .fComponent = 1} (half2(foo) .y)
140 struct ConstructorArgMap {
141 int8_t fArgIndex;
142 int8_t fComponent;
143 };
144
145 int numConstructorArgs = base.type().columns();
146 ConstructorArgMap argMap[4] = {};
147 int writeIdx = 0;
148 for (int argIdx = 0; argIdx < (int)baseArguments.size(); ++argIdx) {
149 const Expression& arg = *baseArguments[argIdx];
150 const Type& argType = arg.type();
151
152 if (!argType.isScalar() && !argType.isVector()) {
153 return nullptr;
154 }
155
156 int argSlots = argType.slotCount();
157 for (int componentIdx = 0; componentIdx < argSlots; ++componentIdx) {
158 argMap[writeIdx].fArgIndex = argIdx;
159 argMap[writeIdx].fComponent = componentIdx;
160 ++writeIdx;
161 }
162 }
163 SkASSERT(writeIdx == numConstructorArgs);
164
165 // Count up the number of times each constructor argument is used by the swizzle.
166 // `half4(bar.yz, half2(foo)).xwxy` -> { 3, 1 }
167 // - bar.yz is referenced 3 times, by `.x_xy`
168 // - half(foo) is referenced 1 time, by `._w__`
169 int8_t exprUsed[4] = {};
170 for (int8_t c : components) {
171 exprUsed[argMap[c].fArgIndex]++;
172 }
173
174 for (int index = 0; index < numConstructorArgs; ++index) {
175 int8_t constructorArgIndex = argMap[index].fArgIndex;
176 const Expression& baseArg = *baseArguments[constructorArgIndex];
177
178 // Check that non-trivial expressions are not swizzled in more than once.
179 if (exprUsed[constructorArgIndex] > 1 && !Analysis::IsTrivialExpression(baseArg)) {
180 return nullptr;
181 }
182 // Check that side-effect-bearing expressions are swizzled in exactly once.
183 if (exprUsed[constructorArgIndex] != 1 && Analysis::HasSideEffects(baseArg)) {
184 return nullptr;
185 }
186 }
187
188 struct ReorderedArgument {
189 int8_t fArgIndex;
190 ComponentArray fComponents;
191 };
192 SkSTArray<4, ReorderedArgument> reorderedArgs;
193 for (int8_t c : components) {
194 const ConstructorArgMap& argument = argMap[c];
195 const Expression& baseArg = *baseArguments[argument.fArgIndex];
196
197 if (baseArg.type().isScalar()) {
198 // This argument is a scalar; add it to the list as-is.
199 SkASSERT(argument.fComponent == 0);
200 reorderedArgs.push_back({argument.fArgIndex,
201 ComponentArray{}});
202 } else {
203 // This argument is a component from a vector.
204 SkASSERT(baseArg.type().isVector());
205 SkASSERT(argument.fComponent < baseArg.type().columns());
206 if (reorderedArgs.empty() ||
207 reorderedArgs.back().fArgIndex != argument.fArgIndex) {
208 // This can't be combined with the previous argument. Add a new one.
209 reorderedArgs.push_back({argument.fArgIndex,
210 ComponentArray{argument.fComponent}});
211 } else {
212 // Since we know this argument uses components, it should already have at least one
213 // component set.
214 SkASSERT(!reorderedArgs.back().fComponents.empty());
215 // Build up the current argument with one more component.
216 reorderedArgs.back().fComponents.push_back(argument.fComponent);
217 }
218 }
219 }
220
221 // Convert our reordered argument list to an actual array of expressions, with the new order and
222 // any new inner swizzles that need to be applied.
223 ExpressionArray newArgs;
224 newArgs.reserve_back(swizzleSize);
225 for (const ReorderedArgument& reorderedArg : reorderedArgs) {
226 std::unique_ptr<Expression> newArg =
227 baseArguments[reorderedArg.fArgIndex]->clone();
228
229 if (reorderedArg.fComponents.empty()) {
230 newArgs.push_back(std::move(newArg));
231 } else {
232 newArgs.push_back(Swizzle::Make(context, pos, std::move(newArg),
233 reorderedArg.fComponents));
234 }
235 }
236
237 // Wrap the new argument list in a compound constructor.
238 return ConstructorCompound::Make(context,
239 pos,
240 componentType.toCompound(context, swizzleSize, /*rows=*/1),
241 std::move(newArgs));
242 }
243
Convert(const Context & context,Position pos,Position maskPos,std::unique_ptr<Expression> base,std::string_view maskString)244 std::unique_ptr<Expression> Swizzle::Convert(const Context& context,
245 Position pos,
246 Position maskPos,
247 std::unique_ptr<Expression> base,
248 std::string_view maskString) {
249 ComponentArray components;
250 for (size_t i = 0; i < maskString.length(); ++i) {
251 char field = maskString[i];
252 switch (field) {
253 case '0': components.push_back(SwizzleComponent::ZERO); break;
254 case '1': components.push_back(SwizzleComponent::ONE); break;
255 case 'x': components.push_back(SwizzleComponent::X); break;
256 case 'r': components.push_back(SwizzleComponent::R); break;
257 case 's': components.push_back(SwizzleComponent::S); break;
258 case 'L': components.push_back(SwizzleComponent::UL); break;
259 case 'y': components.push_back(SwizzleComponent::Y); break;
260 case 'g': components.push_back(SwizzleComponent::G); break;
261 case 't': components.push_back(SwizzleComponent::T); break;
262 case 'T': components.push_back(SwizzleComponent::UT); break;
263 case 'z': components.push_back(SwizzleComponent::Z); break;
264 case 'b': components.push_back(SwizzleComponent::B); break;
265 case 'p': components.push_back(SwizzleComponent::P); break;
266 case 'R': components.push_back(SwizzleComponent::UR); break;
267 case 'w': components.push_back(SwizzleComponent::W); break;
268 case 'a': components.push_back(SwizzleComponent::A); break;
269 case 'q': components.push_back(SwizzleComponent::Q); break;
270 case 'B': components.push_back(SwizzleComponent::UB); break;
271 default:
272 context.fErrors->error(Position::Range(maskPos.startOffset() + i,
273 maskPos.startOffset() + i + 1),
274 String::printf("invalid swizzle component '%c'", field));
275 return nullptr;
276 }
277 }
278 return Convert(context, pos, maskPos, std::move(base), std::move(components));
279 }
280
281 // Swizzles are complicated due to constant components. The most difficult case is a mask like
282 // '.x1w0'. A naive approach might turn that into 'float4(base.x, 1, base.w, 0)', but that evaluates
283 // 'base' twice. We instead group the swizzle mask ('xw') and constants ('1, 0') together and use a
284 // secondary swizzle to put them back into the right order, so in this case we end up with
285 // 'float4(base.xw, 1, 0).xzyw'.
Convert(const Context & context,Position pos,Position rawMaskPos,std::unique_ptr<Expression> base,ComponentArray inComponents)286 std::unique_ptr<Expression> Swizzle::Convert(const Context& context,
287 Position pos,
288 Position rawMaskPos,
289 std::unique_ptr<Expression> base,
290 ComponentArray inComponents) {
291 Position maskPos = rawMaskPos.valid() ? rawMaskPos : pos;
292 if (!validate_swizzle_domain(inComponents)) {
293 context.fErrors->error(maskPos, "invalid swizzle mask '" + mask_string(inComponents) + "'");
294 return nullptr;
295 }
296
297 const Type& baseType = base->type().scalarTypeForLiteral();
298
299 if (!baseType.isVector() && !baseType.isScalar()) {
300 context.fErrors->error(
301 pos, "cannot swizzle value of type '" + baseType.displayName() + "'");
302 return nullptr;
303 }
304
305 if (inComponents.size() > 4) {
306 Position errorPos = rawMaskPos.valid() ? Position::Range(maskPos.startOffset() + 4,
307 maskPos.endOffset())
308 : pos;
309 context.fErrors->error(errorPos,
310 "too many components in swizzle mask '" + mask_string(inComponents) + "'");
311 return nullptr;
312 }
313
314 ComponentArray maskComponents;
315 bool foundXYZW = false;
316 for (int i = 0; i < inComponents.size(); ++i) {
317 switch (inComponents[i]) {
318 case SwizzleComponent::ZERO:
319 case SwizzleComponent::ONE:
320 // Skip over constant fields for now.
321 break;
322 case SwizzleComponent::X:
323 case SwizzleComponent::R:
324 case SwizzleComponent::S:
325 case SwizzleComponent::UL:
326 foundXYZW = true;
327 maskComponents.push_back(SwizzleComponent::X);
328 break;
329 case SwizzleComponent::Y:
330 case SwizzleComponent::G:
331 case SwizzleComponent::T:
332 case SwizzleComponent::UT:
333 foundXYZW = true;
334 if (baseType.columns() >= 2) {
335 maskComponents.push_back(SwizzleComponent::Y);
336 break;
337 }
338 [[fallthrough]];
339 case SwizzleComponent::Z:
340 case SwizzleComponent::B:
341 case SwizzleComponent::P:
342 case SwizzleComponent::UR:
343 foundXYZW = true;
344 if (baseType.columns() >= 3) {
345 maskComponents.push_back(SwizzleComponent::Z);
346 break;
347 }
348 [[fallthrough]];
349 case SwizzleComponent::W:
350 case SwizzleComponent::A:
351 case SwizzleComponent::Q:
352 case SwizzleComponent::UB:
353 foundXYZW = true;
354 if (baseType.columns() >= 4) {
355 maskComponents.push_back(SwizzleComponent::W);
356 break;
357 }
358 [[fallthrough]];
359 default:
360 // The swizzle component references a field that doesn't exist in the base type.
361 context.fErrors->error(
362 Position::Range(maskPos.startOffset() + i,
363 maskPos.startOffset() + i + 1),
364 String::printf("invalid swizzle component '%c'",
365 mask_char(inComponents[i])));
366 return nullptr;
367 }
368 }
369
370 if (!foundXYZW) {
371 context.fErrors->error(maskPos, "swizzle must refer to base expression");
372 return nullptr;
373 }
374
375 // Coerce literals in expressions such as `(12345).xxx` to their actual type.
376 base = baseType.coerceExpression(std::move(base), context);
377 if (!base) {
378 return nullptr;
379 }
380
381 // First, we need a vector expression that is the non-constant portion of the swizzle, packed:
382 // scalar.xxx -> type3(scalar)
383 // scalar.x0x0 -> type2(scalar)
384 // vector.zyx -> vector.zyx
385 // vector.x0y0 -> vector.xy
386 std::unique_ptr<Expression> expr = Swizzle::Make(context, pos, std::move(base), maskComponents);
387
388 // If we have processed the entire swizzle, we're done.
389 if (maskComponents.size() == inComponents.size()) {
390 return expr;
391 }
392
393 // Now we create a constructor that has the correct number of elements for the final swizzle,
394 // with all fields at the start. It's not finished yet; constants we need will be added below.
395 // scalar.x0x0 -> type4(type2(x), ...)
396 // vector.y111 -> type4(vector.y, ...)
397 // vector.z10x -> type4(vector.zx, ...)
398 //
399 // The constructor will have at most three arguments: { base expr, constant 0, constant 1 }
400 ExpressionArray constructorArgs;
401 constructorArgs.reserve_back(3);
402 constructorArgs.push_back(std::move(expr));
403
404 // Apply another swizzle to shuffle the constants into the correct place. Any constant values we
405 // need are also tacked on to the end of the constructor.
406 // scalar.x0x0 -> type4(type2(x), 0).xyxy
407 // vector.y111 -> type2(vector.y, 1).xyyy
408 // vector.z10x -> type4(vector.zx, 1, 0).xzwy
409 const Type* scalarType = &baseType.componentType();
410 ComponentArray swizzleComponents;
411 int maskFieldIdx = 0;
412 int constantFieldIdx = maskComponents.size();
413 int constantZeroIdx = -1, constantOneIdx = -1;
414
415 for (int i = 0; i < inComponents.size(); i++) {
416 switch (inComponents[i]) {
417 case SwizzleComponent::ZERO:
418 if (constantZeroIdx == -1) {
419 // Synthesize a '0' argument at the end of the constructor.
420 constructorArgs.push_back(Literal::Make(pos, /*value=*/0, scalarType));
421 constantZeroIdx = constantFieldIdx++;
422 }
423 swizzleComponents.push_back(constantZeroIdx);
424 break;
425 case SwizzleComponent::ONE:
426 if (constantOneIdx == -1) {
427 // Synthesize a '1' argument at the end of the constructor.
428 constructorArgs.push_back(Literal::Make(pos, /*value=*/1, scalarType));
429 constantOneIdx = constantFieldIdx++;
430 }
431 swizzleComponents.push_back(constantOneIdx);
432 break;
433 default:
434 // The non-constant fields are already in the expected order.
435 swizzleComponents.push_back(maskFieldIdx++);
436 break;
437 }
438 }
439
440 expr = ConstructorCompound::Make(context, pos,
441 scalarType->toCompound(context, constantFieldIdx, /*rows=*/1),
442 std::move(constructorArgs));
443
444 // Create (and potentially optimize-away) the resulting swizzle-expression.
445 return Swizzle::Make(context, pos, std::move(expr), swizzleComponents);
446 }
447
Make(const Context & context,Position pos,std::unique_ptr<Expression> expr,ComponentArray components)448 std::unique_ptr<Expression> Swizzle::Make(const Context& context,
449 Position pos,
450 std::unique_ptr<Expression> expr,
451 ComponentArray components) {
452 const Type& exprType = expr->type();
453 SkASSERTF(exprType.isVector() || exprType.isScalar(),
454 "cannot swizzle type '%s'", exprType.description().c_str());
455 SkASSERT(components.size() >= 1 && components.size() <= 4);
456
457 // Confirm that the component array only contains X/Y/Z/W. (Call MakeWith01 if you want support
458 // for ZERO and ONE. Once initial IR generation is complete, no swizzles should have zeros or
459 // ones in them.)
460 SkASSERT(std::all_of(components.begin(), components.end(), [](int8_t component) {
461 return component >= SwizzleComponent::X &&
462 component <= SwizzleComponent::W;
463 }));
464
465 // SkSL supports splatting a scalar via `scalar.xxxx`, but not all versions of GLSL allow this.
466 // Replace swizzles with equivalent splat constructors (`scalar.xxx` --> `half3(value)`).
467 if (exprType.isScalar()) {
468 return ConstructorSplat::Make(context, pos,
469 exprType.toCompound(context, components.size(), /*rows=*/1),
470 std::move(expr));
471 }
472
473 // Detect identity swizzles like `color.rgba` and optimize it away.
474 if (components.size() == exprType.columns()) {
475 bool identity = true;
476 for (int i = 0; i < components.size(); ++i) {
477 if (components[i] != i) {
478 identity = false;
479 break;
480 }
481 }
482 if (identity) {
483 expr->fPosition = pos;
484 return expr;
485 }
486 }
487
488 // Optimize swizzles of swizzles, e.g. replace `foo.argb.rggg` with `foo.arrr`.
489 if (expr->is<Swizzle>()) {
490 Swizzle& base = expr->as<Swizzle>();
491 ComponentArray combined;
492 for (int8_t c : components) {
493 combined.push_back(base.components()[c]);
494 }
495
496 // It may actually be possible to further simplify this swizzle. Go again.
497 // (e.g. `color.abgr.abgr` --> `color.rgba` --> `color`.)
498 return Swizzle::Make(context, pos, std::move(base.base()), combined);
499 }
500
501 // If we are swizzling a constant expression, we can use its value instead here (so that
502 // swizzles like `colorWhite.x` can be simplified to `1`).
503 const Expression* value = ConstantFolder::GetConstantValueForVariable(*expr);
504
505 // `half4(scalar).zyy` can be optimized to `half3(scalar)`, and `half3(scalar).y` can be
506 // optimized to just `scalar`. The swizzle components don't actually matter, as every field
507 // in a splat constructor holds the same value.
508 if (value->is<ConstructorSplat>()) {
509 const ConstructorSplat& splat = value->as<ConstructorSplat>();
510 return ConstructorSplat::Make(
511 context, pos,
512 splat.type().componentType().toCompound(context, components.size(), /*rows=*/1),
513 splat.argument()->clone());
514 }
515
516 // Swizzles on casts, like `half4(myFloat4).zyy`, can optimize to `half3(myFloat4.zyy)`.
517 if (value->is<ConstructorCompoundCast>()) {
518 const ConstructorCompoundCast& cast = value->as<ConstructorCompoundCast>();
519 const Type& castType = cast.type().componentType().toCompound(context, components.size(),
520 /*rows=*/1);
521 std::unique_ptr<Expression> swizzled = Swizzle::Make(context, pos, cast.argument()->clone(),
522 std::move(components));
523 return (castType.columns() > 1)
524 ? ConstructorCompoundCast::Make(context, pos, castType, std::move(swizzled))
525 : ConstructorScalarCast::Make(context, pos, castType, std::move(swizzled));
526 }
527
528 // Swizzles on compound constructors, like `half4(1, 2, 3, 4).yw`, can become `half2(2, 4)`.
529 if (value->is<ConstructorCompound>()) {
530 const ConstructorCompound& ctor = value->as<ConstructorCompound>();
531 if (auto replacement = optimize_constructor_swizzle(context, pos, ctor, components)) {
532 return replacement;
533 }
534 }
535
536 // The swizzle could not be simplified, so apply the requested swizzle to the base expression.
537 return std::make_unique<Swizzle>(context, pos, std::move(expr), components);
538 }
539
description(OperatorPrecedence) const540 std::string Swizzle::description(OperatorPrecedence) const {
541 std::string result = this->base()->description(OperatorPrecedence::kPostfix) + ".";
542 for (int x : this->components()) {
543 result += "xyzw"[x];
544 }
545 return result;
546 }
547
548 } // namespace SkSL
549