• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Use of this source code is governed by a BSD-style license that can be
5  * found in the LICENSE file.
6  */
7 
8 #include "src/sksl/ir/SkSLSwizzle.h"
9 
10 #include "include/private/SkTOptional.h"
11 #include "include/sksl/SkSLErrorReporter.h"
12 #include "src/sksl/SkSLAnalysis.h"
13 #include "src/sksl/SkSLConstantFolder.h"
14 #include "src/sksl/SkSLProgramSettings.h"
15 #include "src/sksl/ir/SkSLConstructor.h"
16 #include "src/sksl/ir/SkSLConstructorScalarCast.h"
17 #include "src/sksl/ir/SkSLConstructorSplat.h"
18 #include "src/sksl/ir/SkSLLiteral.h"
19 
20 namespace SkSL {
21 
validate_swizzle_domain(const ComponentArray & fields)22 static bool validate_swizzle_domain(const ComponentArray& fields) {
23     enum SwizzleDomain {
24         kCoordinate,
25         kColor,
26         kUV,
27         kRectangle,
28     };
29 
30     skstd::optional<SwizzleDomain> domain;
31 
32     for (int8_t field : fields) {
33         SwizzleDomain fieldDomain;
34         switch (field) {
35             case SwizzleComponent::X:
36             case SwizzleComponent::Y:
37             case SwizzleComponent::Z:
38             case SwizzleComponent::W:
39                 fieldDomain = kCoordinate;
40                 break;
41             case SwizzleComponent::R:
42             case SwizzleComponent::G:
43             case SwizzleComponent::B:
44             case SwizzleComponent::A:
45                 fieldDomain = kColor;
46                 break;
47             case SwizzleComponent::S:
48             case SwizzleComponent::T:
49             case SwizzleComponent::P:
50             case SwizzleComponent::Q:
51                 fieldDomain = kUV;
52                 break;
53             case SwizzleComponent::UL:
54             case SwizzleComponent::UT:
55             case SwizzleComponent::UR:
56             case SwizzleComponent::UB:
57                 fieldDomain = kRectangle;
58                 break;
59             case SwizzleComponent::ZERO:
60             case SwizzleComponent::ONE:
61                 continue;
62             default:
63                 return false;
64         }
65 
66         if (!domain.has_value()) {
67             domain = fieldDomain;
68         } else if (domain != fieldDomain) {
69             return false;
70         }
71     }
72 
73     return true;
74 }
75 
mask_char(int8_t component)76 static char mask_char(int8_t component) {
77     switch (component) {
78         case SwizzleComponent::X:    return 'x';
79         case SwizzleComponent::Y:    return 'y';
80         case SwizzleComponent::Z:    return 'z';
81         case SwizzleComponent::W:    return 'w';
82         case SwizzleComponent::R:    return 'r';
83         case SwizzleComponent::G:    return 'g';
84         case SwizzleComponent::B:    return 'b';
85         case SwizzleComponent::A:    return 'a';
86         case SwizzleComponent::S:    return 's';
87         case SwizzleComponent::T:    return 't';
88         case SwizzleComponent::P:    return 'p';
89         case SwizzleComponent::Q:    return 'q';
90         case SwizzleComponent::UL:   return 'L';
91         case SwizzleComponent::UT:   return 'T';
92         case SwizzleComponent::UR:   return 'R';
93         case SwizzleComponent::UB:   return 'B';
94         case SwizzleComponent::ZERO: return '0';
95         case SwizzleComponent::ONE:  return '1';
96         default: SkUNREACHABLE;
97     }
98 }
99 
mask_string(const ComponentArray & components)100 static String mask_string(const ComponentArray& components) {
101     String result;
102     for (int8_t component : components) {
103         result += mask_char(component);
104     }
105     return result;
106 }
107 
optimize_constructor_swizzle(const Context & context,const AnyConstructor & base,ComponentArray components)108 static std::unique_ptr<Expression> optimize_constructor_swizzle(const Context& context,
109                                                                 const AnyConstructor& base,
110                                                                 ComponentArray components) {
111     auto baseArguments = base.argumentSpan();
112     std::unique_ptr<Expression> replacement;
113     const Type& exprType = base.type();
114     const Type& componentType = exprType.componentType();
115     int swizzleSize = components.size();
116 
117     // Swizzles can duplicate some elements and discard others, e.g.
118     // `half4(1, 2, 3, 4).xxz` --> `half3(1, 1, 3)`. However, there are constraints:
119     // - Expressions with side effects need to occur exactly once, even if they would otherwise be
120     //   swizzle-eliminated
121     // - Non-trivial expressions should not be repeated, but elimination is OK.
122     //
123     // Look up the argument for the constructor at each index. This is typically simple but for
124     // weird cases like `half4(bar.yz, half2(foo))`, it can be harder than it seems. This example
125     // would result in:
126     //     argMap[0] = {.fArgIndex = 0, .fComponent = 0}   (bar.yz     .x)
127     //     argMap[1] = {.fArgIndex = 0, .fComponent = 1}   (bar.yz     .y)
128     //     argMap[2] = {.fArgIndex = 1, .fComponent = 0}   (half2(foo) .x)
129     //     argMap[3] = {.fArgIndex = 1, .fComponent = 1}   (half2(foo) .y)
130     struct ConstructorArgMap {
131         int8_t fArgIndex;
132         int8_t fComponent;
133     };
134 
135     int numConstructorArgs = base.type().columns();
136     ConstructorArgMap argMap[4] = {};
137     int writeIdx = 0;
138     for (int argIdx = 0; argIdx < (int)baseArguments.size(); ++argIdx) {
139         const Expression& arg = *baseArguments[argIdx];
140         const Type& argType = arg.type();
141 
142         if (!argType.isScalar() && !argType.isVector()) {
143             return nullptr;
144         }
145 
146         int argSlots = argType.slotCount();
147         for (int componentIdx = 0; componentIdx < argSlots; ++componentIdx) {
148             argMap[writeIdx].fArgIndex = argIdx;
149             argMap[writeIdx].fComponent = componentIdx;
150             ++writeIdx;
151         }
152     }
153     SkASSERT(writeIdx == numConstructorArgs);
154 
155     // Count up the number of times each constructor argument is used by the swizzle.
156     //    `half4(bar.yz, half2(foo)).xwxy` -> { 3, 1 }
157     // - bar.yz    is referenced 3 times, by `.x_xy`
158     // - half(foo) is referenced 1 time,  by `._w__`
159     int8_t exprUsed[4] = {};
160     for (int8_t c : components) {
161         exprUsed[argMap[c].fArgIndex]++;
162     }
163 
164     for (int index = 0; index < numConstructorArgs; ++index) {
165         int8_t constructorArgIndex = argMap[index].fArgIndex;
166         const Expression& baseArg = *baseArguments[constructorArgIndex];
167 
168         // Check that non-trivial expressions are not swizzled in more than once.
169         if (exprUsed[constructorArgIndex] > 1 && !Analysis::IsTrivialExpression(baseArg)) {
170             return nullptr;
171         }
172         // Check that side-effect-bearing expressions are swizzled in exactly once.
173         if (exprUsed[constructorArgIndex] != 1 && baseArg.hasSideEffects()) {
174             return nullptr;
175         }
176     }
177 
178     struct ReorderedArgument {
179         int8_t fArgIndex;
180         ComponentArray fComponents;
181     };
182     SkSTArray<4, ReorderedArgument> reorderedArgs;
183     for (int8_t c : components) {
184         const ConstructorArgMap& argument = argMap[c];
185         const Expression& baseArg = *baseArguments[argument.fArgIndex];
186 
187         if (baseArg.type().isScalar()) {
188             // This argument is a scalar; add it to the list as-is.
189             SkASSERT(argument.fComponent == 0);
190             reorderedArgs.push_back({argument.fArgIndex,
191                                      ComponentArray{}});
192         } else {
193             // This argument is a component from a vector.
194             SkASSERT(baseArg.type().isVector());
195             SkASSERT(argument.fComponent < baseArg.type().columns());
196             if (reorderedArgs.empty() ||
197                 reorderedArgs.back().fArgIndex != argument.fArgIndex) {
198                 // This can't be combined with the previous argument. Add a new one.
199                 reorderedArgs.push_back({argument.fArgIndex,
200                                          ComponentArray{argument.fComponent}});
201             } else {
202                 // Since we know this argument uses components, it should already have at least one
203                 // component set.
204                 SkASSERT(!reorderedArgs.back().fComponents.empty());
205                 // Build up the current argument with one more component.
206                 reorderedArgs.back().fComponents.push_back(argument.fComponent);
207             }
208         }
209     }
210 
211     // Convert our reordered argument list to an actual array of expressions, with the new order and
212     // any new inner swizzles that need to be applied.
213     ExpressionArray newArgs;
214     newArgs.reserve_back(swizzleSize);
215     for (const ReorderedArgument& reorderedArg : reorderedArgs) {
216         std::unique_ptr<Expression> newArg =
217                 baseArguments[reorderedArg.fArgIndex]->clone();
218 
219         if (reorderedArg.fComponents.empty()) {
220             newArgs.push_back(std::move(newArg));
221         } else {
222             newArgs.push_back(Swizzle::Make(context, std::move(newArg),
223                                             reorderedArg.fComponents));
224         }
225     }
226 
227     // Wrap the new argument list in a constructor.
228     auto ctor = Constructor::Convert(context,
229                                      base.fLine,
230                                      componentType.toCompound(context, swizzleSize, /*rows=*/1),
231                                      std::move(newArgs));
232     SkASSERT(ctor);
233     return ctor;
234 }
235 
Convert(const Context & context,std::unique_ptr<Expression> base,skstd::string_view maskString)236 std::unique_ptr<Expression> Swizzle::Convert(const Context& context,
237                                              std::unique_ptr<Expression> base,
238                                              skstd::string_view maskString) {
239     ComponentArray components;
240     for (char field : maskString) {
241         switch (field) {
242             case '0': components.push_back(SwizzleComponent::ZERO); break;
243             case '1': components.push_back(SwizzleComponent::ONE);  break;
244             case 'x': components.push_back(SwizzleComponent::X);    break;
245             case 'r': components.push_back(SwizzleComponent::R);    break;
246             case 's': components.push_back(SwizzleComponent::S);    break;
247             case 'L': components.push_back(SwizzleComponent::UL);   break;
248             case 'y': components.push_back(SwizzleComponent::Y);    break;
249             case 'g': components.push_back(SwizzleComponent::G);    break;
250             case 't': components.push_back(SwizzleComponent::T);    break;
251             case 'T': components.push_back(SwizzleComponent::UT);   break;
252             case 'z': components.push_back(SwizzleComponent::Z);    break;
253             case 'b': components.push_back(SwizzleComponent::B);    break;
254             case 'p': components.push_back(SwizzleComponent::P);    break;
255             case 'R': components.push_back(SwizzleComponent::UR);   break;
256             case 'w': components.push_back(SwizzleComponent::W);    break;
257             case 'a': components.push_back(SwizzleComponent::A);    break;
258             case 'q': components.push_back(SwizzleComponent::Q);    break;
259             case 'B': components.push_back(SwizzleComponent::UB);   break;
260             default:
261                 context.fErrors->error(base->fLine,
262                         String::printf("invalid swizzle component '%c'", field));
263                 return nullptr;
264         }
265     }
266     return Convert(context, std::move(base), std::move(components));
267 }
268 
269 // Swizzles are complicated due to constant components. The most difficult case is a mask like
270 // '.x1w0'. A naive approach might turn that into 'float4(base.x, 1, base.w, 0)', but that evaluates
271 // 'base' twice. We instead group the swizzle mask ('xw') and constants ('1, 0') together and use a
272 // secondary swizzle to put them back into the right order, so in this case we end up with
273 // 'float4(base.xw, 1, 0).xzyw'.
Convert(const Context & context,std::unique_ptr<Expression> base,ComponentArray inComponents)274 std::unique_ptr<Expression> Swizzle::Convert(const Context& context,
275                                              std::unique_ptr<Expression> base,
276                                              ComponentArray inComponents) {
277     if (!validate_swizzle_domain(inComponents)) {
278         context.fErrors->error(base->fLine,
279                 "invalid swizzle mask '" + mask_string(inComponents) + "'");
280         return nullptr;
281     }
282 
283     const int line = base->fLine;
284     const Type& baseType = base->type();
285 
286     if (!baseType.isVector() && !baseType.isScalar()) {
287         context.fErrors->error(
288                 line, "cannot swizzle value of type '" + baseType.displayName() + "'");
289         return nullptr;
290     }
291 
292     if (inComponents.count() > 4) {
293         context.fErrors->error(line,
294                 "too many components in swizzle mask '" + mask_string(inComponents) + "'");
295         return nullptr;
296     }
297 
298     ComponentArray maskComponents;
299     bool foundXYZW = false;
300     for (int i = 0; i < inComponents.count(); ++i) {
301         switch (inComponents[i]) {
302             case SwizzleComponent::ZERO:
303             case SwizzleComponent::ONE:
304                 // Skip over constant fields for now.
305                 break;
306             case SwizzleComponent::X:
307             case SwizzleComponent::R:
308             case SwizzleComponent::S:
309             case SwizzleComponent::UL:
310                 foundXYZW = true;
311                 maskComponents.push_back(SwizzleComponent::X);
312                 break;
313             case SwizzleComponent::Y:
314             case SwizzleComponent::G:
315             case SwizzleComponent::T:
316             case SwizzleComponent::UT:
317                 foundXYZW = true;
318                 if (baseType.columns() >= 2) {
319                     maskComponents.push_back(SwizzleComponent::Y);
320                     break;
321                 }
322                 [[fallthrough]];
323             case SwizzleComponent::Z:
324             case SwizzleComponent::B:
325             case SwizzleComponent::P:
326             case SwizzleComponent::UR:
327                 foundXYZW = true;
328                 if (baseType.columns() >= 3) {
329                     maskComponents.push_back(SwizzleComponent::Z);
330                     break;
331                 }
332                 [[fallthrough]];
333             case SwizzleComponent::W:
334             case SwizzleComponent::A:
335             case SwizzleComponent::Q:
336             case SwizzleComponent::UB:
337                 foundXYZW = true;
338                 if (baseType.columns() >= 4) {
339                     maskComponents.push_back(SwizzleComponent::W);
340                     break;
341                 }
342                 [[fallthrough]];
343             default:
344                 // The swizzle component references a field that doesn't exist in the base type.
345                 context.fErrors->error(line,
346                        String::printf("invalid swizzle component '%c'",
347                             mask_char(inComponents[i])));
348                 return nullptr;
349         }
350     }
351 
352     if (!foundXYZW) {
353         context.fErrors->error(line, "swizzle must refer to base expression");
354         return nullptr;
355     }
356 
357     // Coerce literals in expressions such as `(12345).xxx` to their actual type.
358     base = baseType.scalarTypeForLiteral().coerceExpression(std::move(base), context);
359     if (!base) {
360         return nullptr;
361     }
362 
363     // First, we need a vector expression that is the non-constant portion of the swizzle, packed:
364     //   scalar.xxx  -> type3(scalar)
365     //   scalar.x0x0 -> type2(scalar)
366     //   vector.zyx  -> vector.zyx
367     //   vector.x0y0 -> vector.xy
368     std::unique_ptr<Expression> expr = Swizzle::Make(context, std::move(base), maskComponents);
369 
370     // If we have processed the entire swizzle, we're done.
371     if (maskComponents.count() == inComponents.count()) {
372         return expr;
373     }
374 
375     // Now we create a constructor that has the correct number of elements for the final swizzle,
376     // with all fields at the start. It's not finished yet; constants we need will be added below.
377     //   scalar.x0x0 -> type4(type2(x), ...)
378     //   vector.y111 -> type4(vector.y, ...)
379     //   vector.z10x -> type4(vector.zx, ...)
380     //
381     // The constructor will have at most three arguments: { base expr, constant 0, constant 1 }
382     ExpressionArray constructorArgs;
383     constructorArgs.reserve_back(3);
384     constructorArgs.push_back(std::move(expr));
385 
386     // Apply another swizzle to shuffle the constants into the correct place. Any constant values we
387     // need are also tacked on to the end of the constructor.
388     //   scalar.x0x0 -> type4(type2(x), 0).xyxy
389     //   vector.y111 -> type4(vector.y, 1).xyyy
390     //   vector.z10x -> type4(vector.zx, 1, 0).xzwy
391     const Type* scalarType = &baseType.componentType();
392     ComponentArray swizzleComponents;
393     int maskFieldIdx = 0;
394     int constantFieldIdx = maskComponents.size();
395     int constantZeroIdx = -1, constantOneIdx = -1;
396 
397     for (int i = 0; i < inComponents.count(); i++) {
398         switch (inComponents[i]) {
399             case SwizzleComponent::ZERO:
400                 if (constantZeroIdx == -1) {
401                     // Synthesize a 'type(0)' argument at the end of the constructor.
402                     constructorArgs.push_back(ConstructorScalarCast::Make(
403                             context, line, *scalarType,
404                             Literal::MakeInt(context, line, /*value=*/0)));
405                     constantZeroIdx = constantFieldIdx++;
406                 }
407                 swizzleComponents.push_back(constantZeroIdx);
408                 break;
409             case SwizzleComponent::ONE:
410                 if (constantOneIdx == -1) {
411                     // Synthesize a 'type(1)' argument at the end of the constructor.
412                     constructorArgs.push_back(ConstructorScalarCast::Make(
413                             context, line, *scalarType,
414                             Literal::MakeInt(context, line, /*value=*/1)));
415                     constantOneIdx = constantFieldIdx++;
416                 }
417                 swizzleComponents.push_back(constantOneIdx);
418                 break;
419             default:
420                 // The non-constant fields are already in the expected order.
421                 swizzleComponents.push_back(maskFieldIdx++);
422                 break;
423         }
424     }
425 
426     expr = Constructor::Convert(context, line,
427                                 scalarType->toCompound(context, constantFieldIdx, /*rows=*/1),
428                                 std::move(constructorArgs));
429     if (!expr) {
430         return nullptr;
431     }
432 
433     return Swizzle::Make(context, std::move(expr), swizzleComponents);
434 }
435 
Make(const Context & context,std::unique_ptr<Expression> expr,ComponentArray components)436 std::unique_ptr<Expression> Swizzle::Make(const Context& context,
437                                           std::unique_ptr<Expression> expr,
438                                           ComponentArray components) {
439     const Type& exprType = expr->type();
440     SkASSERTF(exprType.isVector() || exprType.isScalar(),
441               "cannot swizzle type '%s'", exprType.description().c_str());
442     SkASSERT(components.count() >= 1 && components.count() <= 4);
443 
444     // Confirm that the component array only contains X/Y/Z/W. (Call MakeWith01 if you want support
445     // for ZERO and ONE. Once initial IR generation is complete, no swizzles should have zeros or
446     // ones in them.)
447     SkASSERT(std::all_of(components.begin(), components.end(), [](int8_t component) {
448         return component >= SwizzleComponent::X &&
449                component <= SwizzleComponent::W;
450     }));
451 
452     // SkSL supports splatting a scalar via `scalar.xxxx`, but not all versions of GLSL allow this.
453     // Replace swizzles with equivalent splat constructors (`scalar.xxx` --> `half3(value)`).
454     if (exprType.isScalar()) {
455         int line = expr->fLine;
456         return ConstructorSplat::Make(context, line,
457                                       exprType.toCompound(context, components.size(), /*rows=*/1),
458                                       std::move(expr));
459     }
460 
461     // Detect identity swizzles like `color.rgba` and return the base-expression as-is.
462     if (components.count() == exprType.columns()) {
463         bool identity = true;
464         for (int i = 0; i < components.count(); ++i) {
465             if (components[i] != i) {
466                 identity = false;
467                 break;
468             }
469         }
470         if (identity) {
471             return expr;
472         }
473     }
474 
475     // Optimize swizzles of swizzles, e.g. replace `foo.argb.rggg` with `foo.arrr`.
476     if (expr->is<Swizzle>()) {
477         Swizzle& base = expr->as<Swizzle>();
478         ComponentArray combined;
479         for (int8_t c : components) {
480             combined.push_back(base.components()[c]);
481         }
482 
483         // It may actually be possible to further simplify this swizzle. Go again.
484         // (e.g. `color.abgr.abgr` --> `color.rgba` --> `color`.)
485         return Swizzle::Make(context, std::move(base.base()), combined);
486     }
487 
488     // If we are swizzling a constant expression, we can use its value instead here (so that
489     // swizzles like `colorWhite.x` can be simplified to `1`).
490     const Expression* value = ConstantFolder::GetConstantValueForVariable(*expr);
491 
492     // `half4(scalar).zyy` can be optimized to `half3(scalar)`, and `half3(scalar).y` can be
493     // optimized to just `scalar`. The swizzle components don't actually matter, as every field
494     // in a splat constructor holds the same value.
495     if (value->is<ConstructorSplat>()) {
496         const ConstructorSplat& splat = value->as<ConstructorSplat>();
497         return ConstructorSplat::Make(
498                 context, splat.fLine,
499                 splat.type().componentType().toCompound(context, components.size(), /*rows=*/1),
500                 splat.argument()->clone());
501     }
502 
503     // Optimize swizzles of constructors.
504     if (value->isAnyConstructor()) {
505         const AnyConstructor& ctor = value->asAnyConstructor();
506         if (auto replacement = optimize_constructor_swizzle(context, ctor, components)) {
507             return replacement;
508         }
509     }
510 
511     // The swizzle could not be simplified, so apply the requested swizzle to the base expression.
512     return std::make_unique<Swizzle>(context, std::move(expr), components);
513 }
514 
515 }  // namespace SkSL
516