• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/ir/SkSLSwizzle.h"
9 
10 #include "include/core/SkSpan.h"
11 #include "include/private/SkSLString.h"
12 #include "include/private/base/SkTArray.h"
13 #include "include/sksl/SkSLErrorReporter.h"
14 #include "include/sksl/SkSLOperator.h"
15 #include "src/sksl/SkSLAnalysis.h"
16 #include "src/sksl/SkSLConstantFolder.h"
17 #include "src/sksl/SkSLContext.h"
18 #include "src/sksl/ir/SkSLConstructorCompound.h"
19 #include "src/sksl/ir/SkSLConstructorCompoundCast.h"
20 #include "src/sksl/ir/SkSLConstructorScalarCast.h"
21 #include "src/sksl/ir/SkSLConstructorSplat.h"
22 #include "src/sksl/ir/SkSLLiteral.h"
23 
24 #include <algorithm>
25 #include <cstddef>
26 #include <cstdint>
27 #include <optional>
28 
29 namespace SkSL {
30 
validate_swizzle_domain(const ComponentArray & fields)31 static bool validate_swizzle_domain(const ComponentArray& fields) {
32     enum SwizzleDomain {
33         kCoordinate,
34         kColor,
35         kUV,
36         kRectangle,
37     };
38 
39     std::optional<SwizzleDomain> domain;
40 
41     for (int8_t field : fields) {
42         SwizzleDomain fieldDomain;
43         switch (field) {
44             case SwizzleComponent::X:
45             case SwizzleComponent::Y:
46             case SwizzleComponent::Z:
47             case SwizzleComponent::W:
48                 fieldDomain = kCoordinate;
49                 break;
50             case SwizzleComponent::R:
51             case SwizzleComponent::G:
52             case SwizzleComponent::B:
53             case SwizzleComponent::A:
54                 fieldDomain = kColor;
55                 break;
56             case SwizzleComponent::S:
57             case SwizzleComponent::T:
58             case SwizzleComponent::P:
59             case SwizzleComponent::Q:
60                 fieldDomain = kUV;
61                 break;
62             case SwizzleComponent::UL:
63             case SwizzleComponent::UT:
64             case SwizzleComponent::UR:
65             case SwizzleComponent::UB:
66                 fieldDomain = kRectangle;
67                 break;
68             case SwizzleComponent::ZERO:
69             case SwizzleComponent::ONE:
70                 continue;
71             default:
72                 return false;
73         }
74 
75         if (!domain.has_value()) {
76             domain = fieldDomain;
77         } else if (domain != fieldDomain) {
78             return false;
79         }
80     }
81 
82     return true;
83 }
84 
mask_char(int8_t component)85 static char mask_char(int8_t component) {
86     switch (component) {
87         case SwizzleComponent::X:    return 'x';
88         case SwizzleComponent::Y:    return 'y';
89         case SwizzleComponent::Z:    return 'z';
90         case SwizzleComponent::W:    return 'w';
91         case SwizzleComponent::R:    return 'r';
92         case SwizzleComponent::G:    return 'g';
93         case SwizzleComponent::B:    return 'b';
94         case SwizzleComponent::A:    return 'a';
95         case SwizzleComponent::S:    return 's';
96         case SwizzleComponent::T:    return 't';
97         case SwizzleComponent::P:    return 'p';
98         case SwizzleComponent::Q:    return 'q';
99         case SwizzleComponent::UL:   return 'L';
100         case SwizzleComponent::UT:   return 'T';
101         case SwizzleComponent::UR:   return 'R';
102         case SwizzleComponent::UB:   return 'B';
103         case SwizzleComponent::ZERO: return '0';
104         case SwizzleComponent::ONE:  return '1';
105         default: SkUNREACHABLE;
106     }
107 }
108 
mask_string(const ComponentArray & components)109 static std::string mask_string(const ComponentArray& components) {
110     std::string result;
111     for (int8_t component : components) {
112         result += mask_char(component);
113     }
114     return result;
115 }
116 
optimize_constructor_swizzle(const Context & context,Position pos,const ConstructorCompound & base,ComponentArray components)117 static std::unique_ptr<Expression> optimize_constructor_swizzle(const Context& context,
118                                                                 Position pos,
119                                                                 const ConstructorCompound& base,
120                                                                 ComponentArray components) {
121     auto baseArguments = base.argumentSpan();
122     std::unique_ptr<Expression> replacement;
123     const Type& exprType = base.type();
124     const Type& componentType = exprType.componentType();
125     int swizzleSize = components.size();
126 
127     // Swizzles can duplicate some elements and discard others, e.g.
128     // `half4(1, 2, 3, 4).xxz` --> `half3(1, 1, 3)`. However, there are constraints:
129     // - Expressions with side effects need to occur exactly once, even if they would otherwise be
130     //   swizzle-eliminated
131     // - Non-trivial expressions should not be repeated, but elimination is OK.
132     //
133     // Look up the argument for the constructor at each index. This is typically simple but for
134     // weird cases like `half4(bar.yz, half2(foo))`, it can be harder than it seems. This example
135     // would result in:
136     //     argMap[0] = {.fArgIndex = 0, .fComponent = 0}   (bar.yz     .x)
137     //     argMap[1] = {.fArgIndex = 0, .fComponent = 1}   (bar.yz     .y)
138     //     argMap[2] = {.fArgIndex = 1, .fComponent = 0}   (half2(foo) .x)
139     //     argMap[3] = {.fArgIndex = 1, .fComponent = 1}   (half2(foo) .y)
140     struct ConstructorArgMap {
141         int8_t fArgIndex;
142         int8_t fComponent;
143     };
144 
145     int numConstructorArgs = base.type().columns();
146     ConstructorArgMap argMap[4] = {};
147     int writeIdx = 0;
148     for (int argIdx = 0; argIdx < (int)baseArguments.size(); ++argIdx) {
149         const Expression& arg = *baseArguments[argIdx];
150         const Type& argType = arg.type();
151 
152         if (!argType.isScalar() && !argType.isVector()) {
153             return nullptr;
154         }
155 
156         int argSlots = argType.slotCount();
157         for (int componentIdx = 0; componentIdx < argSlots; ++componentIdx) {
158             argMap[writeIdx].fArgIndex = argIdx;
159             argMap[writeIdx].fComponent = componentIdx;
160             ++writeIdx;
161         }
162     }
163     SkASSERT(writeIdx == numConstructorArgs);
164 
165     // Count up the number of times each constructor argument is used by the swizzle.
166     //    `half4(bar.yz, half2(foo)).xwxy` -> { 3, 1 }
167     // - bar.yz    is referenced 3 times, by `.x_xy`
168     // - half(foo) is referenced 1 time,  by `._w__`
169     int8_t exprUsed[4] = {};
170     for (int8_t c : components) {
171         exprUsed[argMap[c].fArgIndex]++;
172     }
173 
174     for (int index = 0; index < numConstructorArgs; ++index) {
175         int8_t constructorArgIndex = argMap[index].fArgIndex;
176         const Expression& baseArg = *baseArguments[constructorArgIndex];
177 
178         // Check that non-trivial expressions are not swizzled in more than once.
179         if (exprUsed[constructorArgIndex] > 1 && !Analysis::IsTrivialExpression(baseArg)) {
180             return nullptr;
181         }
182         // Check that side-effect-bearing expressions are swizzled in exactly once.
183         if (exprUsed[constructorArgIndex] != 1 && Analysis::HasSideEffects(baseArg)) {
184             return nullptr;
185         }
186     }
187 
188     struct ReorderedArgument {
189         int8_t fArgIndex;
190         ComponentArray fComponents;
191     };
192     SkSTArray<4, ReorderedArgument> reorderedArgs;
193     for (int8_t c : components) {
194         const ConstructorArgMap& argument = argMap[c];
195         const Expression& baseArg = *baseArguments[argument.fArgIndex];
196 
197         if (baseArg.type().isScalar()) {
198             // This argument is a scalar; add it to the list as-is.
199             SkASSERT(argument.fComponent == 0);
200             reorderedArgs.push_back({argument.fArgIndex,
201                                      ComponentArray{}});
202         } else {
203             // This argument is a component from a vector.
204             SkASSERT(baseArg.type().isVector());
205             SkASSERT(argument.fComponent < baseArg.type().columns());
206             if (reorderedArgs.empty() ||
207                 reorderedArgs.back().fArgIndex != argument.fArgIndex) {
208                 // This can't be combined with the previous argument. Add a new one.
209                 reorderedArgs.push_back({argument.fArgIndex,
210                                          ComponentArray{argument.fComponent}});
211             } else {
212                 // Since we know this argument uses components, it should already have at least one
213                 // component set.
214                 SkASSERT(!reorderedArgs.back().fComponents.empty());
215                 // Build up the current argument with one more component.
216                 reorderedArgs.back().fComponents.push_back(argument.fComponent);
217             }
218         }
219     }
220 
221     // Convert our reordered argument list to an actual array of expressions, with the new order and
222     // any new inner swizzles that need to be applied.
223     ExpressionArray newArgs;
224     newArgs.reserve_back(swizzleSize);
225     for (const ReorderedArgument& reorderedArg : reorderedArgs) {
226         std::unique_ptr<Expression> newArg =
227                 baseArguments[reorderedArg.fArgIndex]->clone();
228 
229         if (reorderedArg.fComponents.empty()) {
230             newArgs.push_back(std::move(newArg));
231         } else {
232             newArgs.push_back(Swizzle::Make(context, pos, std::move(newArg),
233                                             reorderedArg.fComponents));
234         }
235     }
236 
237     // Wrap the new argument list in a compound constructor.
238     return ConstructorCompound::Make(context,
239                                      pos,
240                                      componentType.toCompound(context, swizzleSize, /*rows=*/1),
241                                      std::move(newArgs));
242 }
243 
Convert(const Context & context,Position pos,Position maskPos,std::unique_ptr<Expression> base,std::string_view maskString)244 std::unique_ptr<Expression> Swizzle::Convert(const Context& context,
245                                              Position pos,
246                                              Position maskPos,
247                                              std::unique_ptr<Expression> base,
248                                              std::string_view maskString) {
249     ComponentArray components;
250     for (size_t i = 0; i < maskString.length(); ++i) {
251         char field = maskString[i];
252         switch (field) {
253             case '0': components.push_back(SwizzleComponent::ZERO); break;
254             case '1': components.push_back(SwizzleComponent::ONE);  break;
255             case 'x': components.push_back(SwizzleComponent::X);    break;
256             case 'r': components.push_back(SwizzleComponent::R);    break;
257             case 's': components.push_back(SwizzleComponent::S);    break;
258             case 'L': components.push_back(SwizzleComponent::UL);   break;
259             case 'y': components.push_back(SwizzleComponent::Y);    break;
260             case 'g': components.push_back(SwizzleComponent::G);    break;
261             case 't': components.push_back(SwizzleComponent::T);    break;
262             case 'T': components.push_back(SwizzleComponent::UT);   break;
263             case 'z': components.push_back(SwizzleComponent::Z);    break;
264             case 'b': components.push_back(SwizzleComponent::B);    break;
265             case 'p': components.push_back(SwizzleComponent::P);    break;
266             case 'R': components.push_back(SwizzleComponent::UR);   break;
267             case 'w': components.push_back(SwizzleComponent::W);    break;
268             case 'a': components.push_back(SwizzleComponent::A);    break;
269             case 'q': components.push_back(SwizzleComponent::Q);    break;
270             case 'B': components.push_back(SwizzleComponent::UB);   break;
271             default:
272                 context.fErrors->error(Position::Range(maskPos.startOffset() + i,
273                                                        maskPos.startOffset() + i + 1),
274                                        String::printf("invalid swizzle component '%c'", field));
275                 return nullptr;
276         }
277     }
278     return Convert(context, pos, maskPos, std::move(base), std::move(components));
279 }
280 
281 // Swizzles are complicated due to constant components. The most difficult case is a mask like
282 // '.x1w0'. A naive approach might turn that into 'float4(base.x, 1, base.w, 0)', but that evaluates
283 // 'base' twice. We instead group the swizzle mask ('xw') and constants ('1, 0') together and use a
284 // secondary swizzle to put them back into the right order, so in this case we end up with
285 // 'float4(base.xw, 1, 0).xzyw'.
Convert(const Context & context,Position pos,Position rawMaskPos,std::unique_ptr<Expression> base,ComponentArray inComponents)286 std::unique_ptr<Expression> Swizzle::Convert(const Context& context,
287                                              Position pos,
288                                              Position rawMaskPos,
289                                              std::unique_ptr<Expression> base,
290                                              ComponentArray inComponents) {
291     Position maskPos = rawMaskPos.valid() ? rawMaskPos : pos;
292     if (!validate_swizzle_domain(inComponents)) {
293         context.fErrors->error(maskPos, "invalid swizzle mask '" + mask_string(inComponents) + "'");
294         return nullptr;
295     }
296 
297     const Type& baseType = base->type().scalarTypeForLiteral();
298 
299     if (!baseType.isVector() && !baseType.isScalar()) {
300         context.fErrors->error(
301                 pos, "cannot swizzle value of type '" + baseType.displayName() + "'");
302         return nullptr;
303     }
304 
305     if (inComponents.size() > 4) {
306         Position errorPos = rawMaskPos.valid() ? Position::Range(maskPos.startOffset() + 4,
307                                                                  maskPos.endOffset())
308                                                : pos;
309         context.fErrors->error(errorPos,
310                 "too many components in swizzle mask '" + mask_string(inComponents) + "'");
311         return nullptr;
312     }
313 
314     ComponentArray maskComponents;
315     bool foundXYZW = false;
316     for (int i = 0; i < inComponents.size(); ++i) {
317         switch (inComponents[i]) {
318             case SwizzleComponent::ZERO:
319             case SwizzleComponent::ONE:
320                 // Skip over constant fields for now.
321                 break;
322             case SwizzleComponent::X:
323             case SwizzleComponent::R:
324             case SwizzleComponent::S:
325             case SwizzleComponent::UL:
326                 foundXYZW = true;
327                 maskComponents.push_back(SwizzleComponent::X);
328                 break;
329             case SwizzleComponent::Y:
330             case SwizzleComponent::G:
331             case SwizzleComponent::T:
332             case SwizzleComponent::UT:
333                 foundXYZW = true;
334                 if (baseType.columns() >= 2) {
335                     maskComponents.push_back(SwizzleComponent::Y);
336                     break;
337                 }
338                 [[fallthrough]];
339             case SwizzleComponent::Z:
340             case SwizzleComponent::B:
341             case SwizzleComponent::P:
342             case SwizzleComponent::UR:
343                 foundXYZW = true;
344                 if (baseType.columns() >= 3) {
345                     maskComponents.push_back(SwizzleComponent::Z);
346                     break;
347                 }
348                 [[fallthrough]];
349             case SwizzleComponent::W:
350             case SwizzleComponent::A:
351             case SwizzleComponent::Q:
352             case SwizzleComponent::UB:
353                 foundXYZW = true;
354                 if (baseType.columns() >= 4) {
355                     maskComponents.push_back(SwizzleComponent::W);
356                     break;
357                 }
358                 [[fallthrough]];
359             default:
360                 // The swizzle component references a field that doesn't exist in the base type.
361                 context.fErrors->error(
362                         Position::Range(maskPos.startOffset() + i,
363                                         maskPos.startOffset() + i + 1),
364                         String::printf("invalid swizzle component '%c'",
365                                        mask_char(inComponents[i])));
366                 return nullptr;
367         }
368     }
369 
370     if (!foundXYZW) {
371         context.fErrors->error(maskPos, "swizzle must refer to base expression");
372         return nullptr;
373     }
374 
375     // Coerce literals in expressions such as `(12345).xxx` to their actual type.
376     base = baseType.coerceExpression(std::move(base), context);
377     if (!base) {
378         return nullptr;
379     }
380 
381     // First, we need a vector expression that is the non-constant portion of the swizzle, packed:
382     //   scalar.xxx  -> type3(scalar)
383     //   scalar.x0x0 -> type2(scalar)
384     //   vector.zyx  -> vector.zyx
385     //   vector.x0y0 -> vector.xy
386     std::unique_ptr<Expression> expr = Swizzle::Make(context, pos, std::move(base), maskComponents);
387 
388     // If we have processed the entire swizzle, we're done.
389     if (maskComponents.size() == inComponents.size()) {
390         return expr;
391     }
392 
393     // Now we create a constructor that has the correct number of elements for the final swizzle,
394     // with all fields at the start. It's not finished yet; constants we need will be added below.
395     //   scalar.x0x0 -> type4(type2(x), ...)
396     //   vector.y111 -> type4(vector.y, ...)
397     //   vector.z10x -> type4(vector.zx, ...)
398     //
399     // The constructor will have at most three arguments: { base expr, constant 0, constant 1 }
400     ExpressionArray constructorArgs;
401     constructorArgs.reserve_back(3);
402     constructorArgs.push_back(std::move(expr));
403 
404     // Apply another swizzle to shuffle the constants into the correct place. Any constant values we
405     // need are also tacked on to the end of the constructor.
406     //   scalar.x0x0 -> type4(type2(x), 0).xyxy
407     //   vector.y111 -> type2(vector.y, 1).xyyy
408     //   vector.z10x -> type4(vector.zx, 1, 0).xzwy
409     const Type* scalarType = &baseType.componentType();
410     ComponentArray swizzleComponents;
411     int maskFieldIdx = 0;
412     int constantFieldIdx = maskComponents.size();
413     int constantZeroIdx = -1, constantOneIdx = -1;
414 
415     for (int i = 0; i < inComponents.size(); i++) {
416         switch (inComponents[i]) {
417             case SwizzleComponent::ZERO:
418                 if (constantZeroIdx == -1) {
419                     // Synthesize a '0' argument at the end of the constructor.
420                     constructorArgs.push_back(Literal::Make(pos, /*value=*/0, scalarType));
421                     constantZeroIdx = constantFieldIdx++;
422                 }
423                 swizzleComponents.push_back(constantZeroIdx);
424                 break;
425             case SwizzleComponent::ONE:
426                 if (constantOneIdx == -1) {
427                     // Synthesize a '1' argument at the end of the constructor.
428                     constructorArgs.push_back(Literal::Make(pos, /*value=*/1, scalarType));
429                     constantOneIdx = constantFieldIdx++;
430                 }
431                 swizzleComponents.push_back(constantOneIdx);
432                 break;
433             default:
434                 // The non-constant fields are already in the expected order.
435                 swizzleComponents.push_back(maskFieldIdx++);
436                 break;
437         }
438     }
439 
440     expr = ConstructorCompound::Make(context, pos,
441                                      scalarType->toCompound(context, constantFieldIdx, /*rows=*/1),
442                                      std::move(constructorArgs));
443 
444     // Create (and potentially optimize-away) the resulting swizzle-expression.
445     return Swizzle::Make(context, pos, std::move(expr), swizzleComponents);
446 }
447 
Make(const Context & context,Position pos,std::unique_ptr<Expression> expr,ComponentArray components)448 std::unique_ptr<Expression> Swizzle::Make(const Context& context,
449                                           Position pos,
450                                           std::unique_ptr<Expression> expr,
451                                           ComponentArray components) {
452     const Type& exprType = expr->type();
453     SkASSERTF(exprType.isVector() || exprType.isScalar(),
454               "cannot swizzle type '%s'", exprType.description().c_str());
455     SkASSERT(components.size() >= 1 && components.size() <= 4);
456 
457     // Confirm that the component array only contains X/Y/Z/W. (Call MakeWith01 if you want support
458     // for ZERO and ONE. Once initial IR generation is complete, no swizzles should have zeros or
459     // ones in them.)
460     SkASSERT(std::all_of(components.begin(), components.end(), [](int8_t component) {
461         return component >= SwizzleComponent::X &&
462                component <= SwizzleComponent::W;
463     }));
464 
465     // SkSL supports splatting a scalar via `scalar.xxxx`, but not all versions of GLSL allow this.
466     // Replace swizzles with equivalent splat constructors (`scalar.xxx` --> `half3(value)`).
467     if (exprType.isScalar()) {
468         return ConstructorSplat::Make(context, pos,
469                                       exprType.toCompound(context, components.size(), /*rows=*/1),
470                                       std::move(expr));
471     }
472 
473     // Detect identity swizzles like `color.rgba` and optimize it away.
474     if (components.size() == exprType.columns()) {
475         bool identity = true;
476         for (int i = 0; i < components.size(); ++i) {
477             if (components[i] != i) {
478                 identity = false;
479                 break;
480             }
481         }
482         if (identity) {
483             expr->fPosition = pos;
484             return expr;
485         }
486     }
487 
488     // Optimize swizzles of swizzles, e.g. replace `foo.argb.rggg` with `foo.arrr`.
489     if (expr->is<Swizzle>()) {
490         Swizzle& base = expr->as<Swizzle>();
491         ComponentArray combined;
492         for (int8_t c : components) {
493             combined.push_back(base.components()[c]);
494         }
495 
496         // It may actually be possible to further simplify this swizzle. Go again.
497         // (e.g. `color.abgr.abgr` --> `color.rgba` --> `color`.)
498         return Swizzle::Make(context, pos, std::move(base.base()), combined);
499     }
500 
501     // If we are swizzling a constant expression, we can use its value instead here (so that
502     // swizzles like `colorWhite.x` can be simplified to `1`).
503     const Expression* value = ConstantFolder::GetConstantValueForVariable(*expr);
504 
505     // `half4(scalar).zyy` can be optimized to `half3(scalar)`, and `half3(scalar).y` can be
506     // optimized to just `scalar`. The swizzle components don't actually matter, as every field
507     // in a splat constructor holds the same value.
508     if (value->is<ConstructorSplat>()) {
509         const ConstructorSplat& splat = value->as<ConstructorSplat>();
510         return ConstructorSplat::Make(
511                 context, pos,
512                 splat.type().componentType().toCompound(context, components.size(), /*rows=*/1),
513                 splat.argument()->clone());
514     }
515 
516     // Swizzles on casts, like `half4(myFloat4).zyy`, can optimize to `half3(myFloat4.zyy)`.
517     if (value->is<ConstructorCompoundCast>()) {
518         const ConstructorCompoundCast& cast = value->as<ConstructorCompoundCast>();
519         const Type& castType = cast.type().componentType().toCompound(context, components.size(),
520                                                                       /*rows=*/1);
521         std::unique_ptr<Expression> swizzled = Swizzle::Make(context, pos, cast.argument()->clone(),
522                                                              std::move(components));
523         return (castType.columns() > 1)
524                        ? ConstructorCompoundCast::Make(context, pos, castType, std::move(swizzled))
525                        : ConstructorScalarCast::Make(context, pos, castType, std::move(swizzled));
526     }
527 
528     // Swizzles on compound constructors, like `half4(1, 2, 3, 4).yw`, can become `half2(2, 4)`.
529     if (value->is<ConstructorCompound>()) {
530         const ConstructorCompound& ctor = value->as<ConstructorCompound>();
531         if (auto replacement = optimize_constructor_swizzle(context, pos, ctor, components)) {
532             return replacement;
533         }
534     }
535 
536     // The swizzle could not be simplified, so apply the requested swizzle to the base expression.
537     return std::make_unique<Swizzle>(context, pos, std::move(expr), components);
538 }
539 
description(OperatorPrecedence) const540 std::string Swizzle::description(OperatorPrecedence) const {
541     std::string result = this->base()->description(OperatorPrecedence::kPostfix) + ".";
542     for (int x : this->components()) {
543         result += "xyzw"[x];
544     }
545     return result;
546 }
547 
548 }  // namespace SkSL
549