• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
7 #include <algorithm>
8 #include <cstring>
9 #include <functional>
10 #include <numeric>
11 #include <unordered_map>
12 #include <unordered_set>
13 
14 #include "compiler/translator/Compiler.h"
15 #include "compiler/translator/msl/AstHelpers.h"
16 #include "compiler/translator/msl/ModifyStruct.h"
17 #include "compiler/translator/msl/TranslatorMSL.h"
18 
19 using namespace sh;
20 
21 ////////////////////////////////////////////////////////////////////////////////
22 
size() const23 size_t ModifiedStructMachineries::size() const
24 {
25     return ordering.size();
26 }
27 
at(size_t index) const28 const ModifiedStructMachinery &ModifiedStructMachineries::at(size_t index) const
29 {
30     ASSERT(index < size());
31     const TStructure *s              = ordering[index];
32     const ModifiedStructMachinery *m = find(*s);
33     ASSERT(m);
34     return *m;
35 }
36 
find(const TStructure & s) const37 const ModifiedStructMachinery *ModifiedStructMachineries::find(const TStructure &s) const
38 {
39     auto iter = originalToMachinery.find(&s);
40     if (iter == originalToMachinery.end())
41     {
42         return nullptr;
43     }
44     return &iter->second;
45 }
46 
insert(const TStructure & s,const ModifiedStructMachinery & machinery)47 void ModifiedStructMachineries::insert(const TStructure &s,
48                                        const ModifiedStructMachinery &machinery)
49 {
50     ASSERT(!find(s));
51     originalToMachinery[&s] = machinery;
52     ordering.push_back(&s);
53 }
54 
55 ////////////////////////////////////////////////////////////////////////////////
56 
57 namespace
58 {
59 
Flatten(SymbolEnv & symbolEnv,TIntermTyped & node)60 TIntermTyped &Flatten(SymbolEnv &symbolEnv, TIntermTyped &node)
61 {
62     auto &type = node.getType();
63     ASSERT(type.isArray());
64 
65     auto &retType = InnermostType(type);
66     retType.makeArray(1);
67 
68     return symbolEnv.callFunctionOverload(Name("flatten"), retType, *new TIntermSequence{&node});
69 }
70 
// Empty tag type; constructing a PathItem from it yields an item of kind FlattenArray.
struct FlattenArray
{};
73 
74 struct PathItem
75 {
76     enum class Type
77     {
78         Field,         // Struct field indexing.
79         Index,         // Array, vector, or matrix indexing.
80         FlattenArray,  // Array of any rank -> pointer of innermost type.
81     };
82 
PathItem__anon7cb117bf0111::PathItem83     PathItem(const TField &field) : field(&field), type(Type::Field) {}
PathItem__anon7cb117bf0111::PathItem84     PathItem(int index) : index(index), type(Type::Index) {}
PathItem__anon7cb117bf0111::PathItem85     PathItem(unsigned index) : PathItem(static_cast<int>(index)) {}
PathItem__anon7cb117bf0111::PathItem86     PathItem(FlattenArray flatten) : type(Type::FlattenArray) {}
87 
88     union
89     {
90         const TField *field;
91         int index;
92     };
93     Type type;
94 };
95 
BuildPathAccess(SymbolEnv & symbolEnv,const TVariable & var,const std::vector<PathItem> & path)96 TIntermTyped &BuildPathAccess(SymbolEnv &symbolEnv,
97                               const TVariable &var,
98                               const std::vector<PathItem> &path)
99 {
100     TIntermTyped *curr = new TIntermSymbol(&var);
101     for (const PathItem &item : path)
102     {
103         switch (item.type)
104         {
105             case PathItem::Type::Field:
106                 curr = &AccessField(*curr, item.field->name());
107                 break;
108             case PathItem::Type::Index:
109                 curr = &AccessIndex(*curr, item.index);
110                 break;
111             case PathItem::Type::FlattenArray:
112             {
113                 curr = &Flatten(symbolEnv, *curr);
114             }
115             break;
116         }
117     }
118     return *curr;
119 }
120 
121 ////////////////////////////////////////////////////////////////////////////////
122 
// Aliases that document which side of a conversion a variable or expression belongs to.
// Both members of each pair name the same underlying type.
123 using OriginalParam = const TVariable &;
124 using ModifiedParam = const TVariable &;
125 
// Expressions that access (part of) the original / modified struct instance.
126 using OriginalAccess = TIntermTyped;
127 using ModifiedAccess = TIntermTyped;
128 
// A pair of expressions, one on the original struct and one on the modified struct, that a
// conversion assigns between.
129 struct Access
130 {
131     OriginalAccess &original;
132     ModifiedAccess &modified;
133 
    // Extra information handed to a ConversionFunc when it is invoked.
134     struct Env
135     {
        // Direction of the conversion being generated.
136         const ConvertType type;
137     };
138 };
139 
// Callback that receives the original-side and modified-side expressions of a field and
// returns the pair of expressions that should actually be assigned to each other.
140 using ConversionFunc = std::function<Access(Access::Env &, OriginalAccess &, ModifiedAccess &)>;
141 
142 class ConvertStructState : angle::NonCopyable
143 {
144   private:
    // One registered conversion step. Exactly one of stdFunc / astFunc is set (publish()
    // asserts this). Member order matters: addConversion() fills this struct positionally.
145     struct ConversionInfo
146     {
        // C++ callback that produces the expressions to assign; empty when astFunc is used.
147         ConversionFunc stdFunc;
        // AST function to call instead of a plain assignment; nullptr when stdFunc is used.
148         const TFunction *astFunc;
        // Path from the original struct instance to the value being converted.
149         std::vector<PathItem> pathItems;
        // Name of the corresponding field in the modified struct.
150         ImmutableString pathName;
151     };
152 
153   public:
ConvertStructState(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,const ModifyStructConfig & config,ModifiedStructMachineries & outMachineries,const bool isUBO)154     ConvertStructState(TCompiler &compiler,
155                        SymbolEnv &symbolEnv,
156                        IdGen &idGen,
157                        const ModifyStructConfig &config,
158                        ModifiedStructMachineries &outMachineries,
159                        const bool isUBO)
160         : mCompiler(compiler),
161           config(config),
162           symbolEnv(symbolEnv),
163           modifiedFields(*new TFieldList()),
164           symbolTable(symbolEnv.symbolTable()),
165           idGen(idGen),
166           outMachineries(outMachineries),
167           isUBO(isUBO)
168     {}
169 
~ConvertStructState()170     ~ConvertStructState()
171     {
172         ASSERT(namePath.empty());
173         ASSERT(namePathSizes.empty());
174     }
175 
publish(const TStructure & originalStruct,const Name & modifiedStructName)176     void publish(const TStructure &originalStruct, const Name &modifiedStructName)
177     {
178         const bool isOriginalToModified = config.convertType == ConvertType::OriginalToModified;
179 
180         auto &modifiedStruct = *new TStructure(&symbolTable, modifiedStructName.rawName(),
181                                                &modifiedFields, modifiedStructName.symbolType());
182 
183         auto &func = *new TFunction(
184             &symbolTable,
185             idGen.createNewName(isOriginalToModified ? "originalToModified" : "modifiedToOriginal")
186                 .rawName(),
187             SymbolType::AngleInternal, new TType(TBasicType::EbtVoid), false);
188 
189         OriginalParam originalParam =
190             CreateInstanceVariable(symbolTable, originalStruct, Name("original"));
191         ModifiedParam modifiedParam =
192             CreateInstanceVariable(symbolTable, modifiedStruct, Name("modified"));
193 
194         symbolEnv.markAsReference(originalParam, AddressSpace::Thread);
195         symbolEnv.markAsReference(modifiedParam, config.externalAddressSpace);
196         if (isOriginalToModified)
197         {
198             func.addParameter(&originalParam);
199             func.addParameter(&modifiedParam);
200         }
201         else
202         {
203             func.addParameter(&modifiedParam);
204             func.addParameter(&originalParam);
205         }
206 
207         TIntermBlock &body = *new TIntermBlock();
208 
209         Access::Env env{config.convertType};
210 
211         for (ConversionInfo &info : conversionInfos)
212         {
213             auto convert = [&](OriginalAccess &original, ModifiedAccess &modified) {
214                 if (info.astFunc)
215                 {
216                     ASSERT(!info.stdFunc);
217                     TIntermTyped &src  = isOriginalToModified ? modified : original;
218                     TIntermTyped &dest = isOriginalToModified ? original : modified;
219                     body.appendStatement(TIntermAggregate::CreateFunctionCall(
220                         *info.astFunc, new TIntermSequence{&dest, &src}));
221                 }
222                 else
223                 {
224                     ASSERT(info.stdFunc);
225                     Access access      = info.stdFunc(env, original, modified);
226                     TIntermTyped &src  = isOriginalToModified ? access.original : access.modified;
227                     TIntermTyped &dest = isOriginalToModified ? access.modified : access.original;
228                     body.appendStatement(new TIntermBinary(TOperator::EOpAssign, &dest, &src));
229                 }
230             };
231 
232             OriginalAccess *original = &BuildPathAccess(symbolEnv, originalParam, info.pathItems);
233             ModifiedAccess *modified = &AccessField(modifiedParam, info.pathName);
234 
235             const TType ot = original->getType();
236             const TType mt = modified->getType();
237             ASSERT(ot.isArray() == mt.isArray());
238 
239             // Clip distance output uses float[n] type, so the field must be assigned per-element
240             // when filling the modified struct. Explicit path name is used because original types
241             // are not available here.
242             if (ot.isArray() && (ot.getLayoutQualifier().matrixPacking == EmpRowMajor || ot != mt ||
243                                  info.pathName == ImmutableString("gl_ClipDistance")))
244             {
245                 ASSERT(ot.getArraySizes() == mt.getArraySizes());
246                 if (ot.isArrayOfArrays())
247                 {
248                     original = &Flatten(symbolEnv, *original);
249                     modified = &Flatten(symbolEnv, *modified);
250                 }
251                 const int volume = static_cast<int>(ot.getArraySizeProduct());
252                 for (int i = 0; i < volume; ++i)
253                 {
254                     if (i != 0)
255                     {
256                         original = original->deepCopy();
257                         modified = modified->deepCopy();
258                     }
259                     OriginalAccess &o = AccessIndex(*original, i);
260                     OriginalAccess &m = AccessIndex(*modified, i);
261                     convert(o, m);
262                 }
263             }
264             else
265             {
266                 convert(*original, *modified);
267             }
268         }
269 
270         auto *funcProto = new TIntermFunctionPrototype(&func);
271         auto *funcDef   = new TIntermFunctionDefinition(funcProto, &body);
272 
273         ModifiedStructMachinery machinery;
274         machinery.modifiedStruct                   = &modifiedStruct;
275         machinery.getConverter(config.convertType) = funcDef;
276 
277         outMachineries.insert(originalStruct, machinery);
278     }
279 
pushPath(PathItem const & item)280     void pushPath(PathItem const &item)
281     {
282         pathItems.push_back(item);
283 
284         switch (item.type)
285         {
286             case PathItem::Type::Field:
287                 pushNamePath(item.field->name().data());
288                 break;
289 
290             case PathItem::Type::Index:
291                 pushNamePath(item.index);
292                 break;
293 
294             case PathItem::Type::FlattenArray:
295                 namePathSizes.push_back(namePath.size());
296                 break;
297         }
298     }
299 
popPath()300     void popPath()
301     {
302         ASSERT(!namePath.empty());
303         ASSERT(!namePathSizes.empty());
304         namePath.resize(namePathSizes.back());
305         namePathSizes.pop_back();
306 
307         ASSERT(!pathItems.empty());
308         pathItems.pop_back();
309     }
310 
finalize(const bool allowPadding)311     void finalize(const bool allowPadding)
312     {
313         ASSERT(!finalized);
314         finalized = true;
315         introducePacking();
316         ASSERT(metalLayoutTotal == Layout::Identity());
317         // Only pad substructs. We don't want to pad the structure that contains all the UBOs, only
318         // individual UBOs.
319         if (allowPadding)
320             introducePadding();
321     }
322 
addModifiedField(const TField & field,TType & newType,TLayoutBlockStorage storage,TLayoutMatrixPacking packing,const AddressSpace * addressSpace)323     void addModifiedField(const TField &field,
324                           TType &newType,
325                           TLayoutBlockStorage storage,
326                           TLayoutMatrixPacking packing,
327                           const AddressSpace *addressSpace)
328     {
329         TLayoutQualifier layoutQualifier = newType.getLayoutQualifier();
330         layoutQualifier.blockStorage     = storage;
331         layoutQualifier.matrixPacking    = packing;
332         newType.setLayoutQualifier(layoutQualifier);
333 
334         const ImmutableString pathName(namePath);
335         TField *modifiedField = new TField(&newType, pathName, field.line(), field.symbolType());
336         if (addressSpace)
337         {
338             symbolEnv.markAsPointer(*modifiedField, *addressSpace);
339         }
340         if (symbolEnv.isUBO(field))
341         {
342             symbolEnv.markAsUBO(*modifiedField);
343         }
344         modifiedFields.push_back(modifiedField);
345     }
346 
addConversion(const ConversionFunc & func)347     void addConversion(const ConversionFunc &func)
348     {
349         ASSERT(!modifiedFields.empty());
350         conversionInfos.push_back({func, nullptr, pathItems, modifiedFields.back()->name()});
351     }
352 
addConversion(const TFunction & func)353     void addConversion(const TFunction &func)
354     {
355         ASSERT(!modifiedFields.empty());
356         conversionInfos.push_back({{}, &func, pathItems, modifiedFields.back()->name()});
357     }
358 
hasPacking() const359     bool hasPacking() const { return containsPacked; }
360 
hasPadding() const361     bool hasPadding() const { return padFieldCount > 0; }
362 
recurse(const TStructure & structure,ModifiedStructMachinery & outMachinery,const bool isUBORecurse)363     bool recurse(const TStructure &structure,
364                  ModifiedStructMachinery &outMachinery,
365                  const bool isUBORecurse)
366     {
367         const ModifiedStructMachinery *m = outMachineries.find(structure);
368         if (m == nullptr)
369         {
370             TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
371             reflection->addOriginalName(structure.uniqueId().get(), structure.name().data());
372             const Name name = idGen.createNewName(structure.name().data());
373             if (!TryCreateModifiedStruct(mCompiler, symbolEnv, idGen, config, structure, name,
374                                          outMachineries, isUBORecurse, config.allowPadding))
375             {
376                 return false;
377             }
378             m = outMachineries.find(structure);
379             ASSERT(m);
380         }
381         outMachinery = *m;
382         return true;
383     }
384 
getIsUBO() const385     bool getIsUBO() const { return isUBO; }
386 
387   private:
addPadding(size_t padAmount,bool updateLayout)388     void addPadding(size_t padAmount, bool updateLayout)
389     {
390         if (padAmount == 0)
391         {
392             return;
393         }
394 
395         const size_t begin = modifiedFields.size();
396 
397         // Iteratively adding in scalar or vector padding because some struct types will not
398         // allow matrix or array members.
399         while (padAmount > 0)
400         {
401             TType *padType;
402             if (padAmount >= 16)
403             {
404                 padAmount -= 16;
405                 padType = new TType(TBasicType::EbtFloat, 4);
406             }
407             else if (padAmount >= 8)
408             {
409                 padAmount -= 8;
410                 padType = new TType(TBasicType::EbtFloat, 2);
411             }
412             else if (padAmount >= 4)
413             {
414                 padAmount -= 4;
415                 padType = new TType(TBasicType::EbtFloat);
416             }
417             else if (padAmount >= 2)
418             {
419                 padAmount -= 2;
420                 padType = new TType(TBasicType::EbtBool, 2);
421             }
422             else
423             {
424                 ASSERT(padAmount == 1);
425                 padAmount -= 1;
426                 padType = new TType(TBasicType::EbtBool);
427             }
428 
429             if (padType->getBasicType() != EbtBool)
430             {
431                 padType->setPrecision(EbpLow);
432             }
433 
434             if (updateLayout)
435             {
436                 metalLayoutTotal += MetalLayoutOf(*padType);
437             }
438 
439             const Name name = idGen.createNewName("pad");
440             modifiedFields.push_back(
441                 new TField(padType, name.rawName(), kNoSourceLoc, name.symbolType()));
442             ++padFieldCount;
443         }
444 
445         std::reverse(modifiedFields.begin() + begin, modifiedFields.end());
446     }
447 
introducePacking()448     void introducePacking()
449     {
450         if (!config.allowPacking)
451         {
452             return;
453         }
454 
455         auto setUnpackedStorage = [](TType &type) {
456             TLayoutBlockStorage storage = type.getLayoutQualifier().blockStorage;
457             switch (storage)
458             {
459                 case TLayoutBlockStorage::EbsShared:
460                     storage = TLayoutBlockStorage::EbsStd140;
461                     break;
462                 case TLayoutBlockStorage::EbsPacked:
463                     storage = TLayoutBlockStorage::EbsStd430;
464                     break;
465                 case TLayoutBlockStorage::EbsStd140:
466                 case TLayoutBlockStorage::EbsStd430:
467                 case TLayoutBlockStorage::EbsUnspecified:
468                     break;
469             }
470             SetBlockStorage(type, storage);
471         };
472 
473         Layout glslLayoutTotal = Layout::Identity();
474         const size_t size      = modifiedFields.size();
475 
476         for (size_t i = 0; i < size; ++i)
477         {
478             TField &curr           = *modifiedFields[i];
479             TType &currType        = *curr.type();
480             const bool canBePacked = CanBePacked(currType);
481 
482             auto dontPack = [&]() {
483                 if (canBePacked)
484                 {
485                     setUnpackedStorage(currType);
486                 }
487                 glslLayoutTotal += GlslLayoutOf(currType);
488             };
489 
490             if (!CanBePacked(currType))
491             {
492                 dontPack();
493                 continue;
494             }
495 
496             const Layout packedGlslLayout           = GlslLayoutOf(currType);
497             const TLayoutBlockStorage packedStorage = currType.getLayoutQualifier().blockStorage;
498             setUnpackedStorage(currType);
499             const Layout unpackedGlslLayout = GlslLayoutOf(currType);
500             SetBlockStorage(currType, packedStorage);
501 
502             ASSERT(packedGlslLayout.sizeOf <= unpackedGlslLayout.sizeOf);
503             if (packedGlslLayout.sizeOf == unpackedGlslLayout.sizeOf)
504             {
505                 dontPack();
506                 continue;
507             }
508 
509             const size_t j = i + 1;
510             if (j == size)
511             {
512                 dontPack();
513                 break;
514             }
515 
516             const size_t pad            = unpackedGlslLayout.sizeOf - packedGlslLayout.sizeOf;
517             const TField &next          = *modifiedFields[j];
518             const Layout nextGlslLayout = GlslLayoutOf(*next.type());
519 
520             if (pad < nextGlslLayout.sizeOf)
521             {
522                 dontPack();
523                 continue;
524             }
525 
526             symbolEnv.markAsPacked(curr);
527             glslLayoutTotal += packedGlslLayout;
528             containsPacked = true;
529         }
530     }
531 
introducePadding()532     void introducePadding()
533     {
534         if (!config.allowPadding)
535         {
536             return;
537         }
538 
539         MetalLayoutOfConfig layoutConfig;
540         layoutConfig.disablePacking             = !config.allowPacking;
541         layoutConfig.assumeStructsAreTailPadded = true;
542 
543         TFieldList fields = std::move(modifiedFields);
544         ASSERT(!fields.empty());  // GLSL requires at least one member.
545 
546         const TField *const first = fields.front();
547 
548         for (TField *field : fields)
549         {
550             const TType &type = *field->type();
551 
552             const Layout glslLayout  = GlslLayoutOf(type);
553             const Layout metalLayout = MetalLayoutOf(type, layoutConfig);
554 
555             size_t prePadAmount = 0;
556             if (glslLayout.alignOf > metalLayout.alignOf && field != first)
557             {
558                 const size_t prePaddedSize = metalLayoutTotal.sizeOf;
559                 metalLayoutTotal.requireAlignment(glslLayout.alignOf, true);
560                 const size_t paddedSize = metalLayoutTotal.sizeOf;
561                 prePadAmount            = paddedSize - prePaddedSize;
562                 metalLayoutTotal += metalLayout;
563                 addPadding(prePadAmount, false);  // Note: requireAlignment() already updated layout
564             }
565             else
566             {
567                 metalLayoutTotal += metalLayout;
568             }
569 
570             modifiedFields.push_back(field);
571 
572             if (glslLayout.sizeOf > metalLayout.sizeOf && field != fields.back())
573             {
574                 const bool updateLayout = true;  // XXX: Correct?
575                 const size_t padAmount  = glslLayout.sizeOf - metalLayout.sizeOf;
576                 addPadding(padAmount, updateLayout);
577             }
578         }
579     }
580 
pushNamePath(const char * extra)581     void pushNamePath(const char *extra)
582     {
583         ASSERT(extra && *extra != '\0');
584         namePathSizes.push_back(namePath.size());
585         const char *p = extra;
586         if (namePath.empty())
587         {
588             namePath = p;
589             return;
590         }
591         while (*p == '_')
592         {
593             ++p;
594         }
595         if (*p == '\0')
596         {
597             p = "x";
598         }
599         if (namePath.back() != '_')
600         {
601             namePath += '_';
602         }
603         namePath += p;
604     }
605 
pushNamePath(unsigned extra)606     void pushNamePath(unsigned extra)
607     {
608         char buffer[std::numeric_limits<unsigned>::digits10 + 1];
609         snprintf(buffer, sizeof(buffer), "%u", extra);
610         pushNamePath(buffer);
611     }
612 
613   public:
    // Collaborators that the free Modify* functions read directly through the state.
614     TCompiler &mCompiler;
615     const ModifyStructConfig &config;
616     SymbolEnv &symbolEnv;
617 
618   private:
    // Field list of the modified struct being built; handed to the TStructure in publish().
619     TFieldList &modifiedFields;
    // Running Metal layout accumulated by introducePadding() / addPadding().
620     Layout metalLayoutTotal = Layout::Identity();
    // Number of padding fields inserted by addPadding().
621     size_t padFieldCount    = 0;
    // Set when introducePacking() marks a field as packed.
622     bool containsPacked     = false;
    // Set by finalize(), which asserts it is only called once.
623     bool finalized          = false;
624 
    // Current access path from the original struct, maintained by pushPath() / popPath().
625     std::vector<PathItem> pathItems;
626 
    // namePath lengths saved on each push so popPath() can restore the previous name.
627     std::vector<size_t> namePathSizes;
    // Name built from the current path; used as the name of newly added modified fields.
628     std::string namePath;
629 
    // Conversions registered via addConversion(), consumed in order by publish().
630     std::vector<ConversionInfo> conversionInfos;
631     TSymbolTable &symbolTable;
632     IdGen &idGen;
    // Registry that publish() inserts the finished machinery into.
633     ModifiedStructMachineries &outMachineries;
634     const bool isUBO;
635 };
636 
637 ////////////////////////////////////////////////////////////////////////////////
638 
// Signature shared by the field-modification functions below: each receives the conversion
// state, the original field and the block storage / matrix packing in effect for it.
639 using ModifyFunc = bool (*)(ConvertStructState &state,
640                             const TField &field,
641                             const TLayoutBlockStorage storage,
642                             const TLayoutMatrixPacking packing);
643 
// Forward declaration; InlineStruct() calls this for every subfield of an inlined struct.
// The definition is not visible in this part of the file.
644 bool ModifyRecursive(ConvertStructState &state,
645                      const TField &field,
646                      const TLayoutBlockStorage storage,
647                      const TLayoutMatrixPacking packing);
648 
IdentityModify(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)649 bool IdentityModify(ConvertStructState &state,
650                     const TField &field,
651                     const TLayoutBlockStorage storage,
652                     const TLayoutMatrixPacking packing)
653 {
654     const TType &type = *field.type();
655     state.addModifiedField(field, CloneType(type), storage, packing, nullptr);
656     state.addConversion([=](Access::Env &, OriginalAccess &o, ModifiedAccess &m) {
657         return Access{o, m};
658     });
659     return false;
660 }
661 
InlineStruct(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)662 bool InlineStruct(ConvertStructState &state,
663                   const TField &field,
664                   const TLayoutBlockStorage storage,
665                   const TLayoutMatrixPacking packing)
666 {
667     const TType &type              = *field.type();
668     const TStructure *substructure = state.symbolEnv.remap(type.getStruct());
669     if (!substructure)
670     {
671         return false;
672     }
673     if (type.isArray())
674     {
675         return false;
676     }
677     if (!state.config.inlineStruct(field))
678     {
679         return false;
680     }
681 
682     const TFieldList &subfields = substructure->fields();
683     for (const TField *subfield : subfields)
684     {
685         const TType &subtype                  = *subfield->type();
686         const TLayoutBlockStorage substorage  = Overlay(storage, subtype);
687         const TLayoutMatrixPacking subpacking = Overlay(packing, subtype);
688         ModifyRecursive(state, *subfield, substorage, subpacking);
689     }
690 
691     return true;
692 }
693 
RecurseStruct(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)694 bool RecurseStruct(ConvertStructState &state,
695                    const TField &field,
696                    const TLayoutBlockStorage storage,
697                    const TLayoutMatrixPacking packing)
698 {
699     const TType &type              = *field.type();
700     const TStructure *substructure = state.symbolEnv.remap(type.getStruct());
701     if (!substructure)
702     {
703         return false;
704     }
705     if (!state.config.recurseStruct(field))
706     {
707         return false;
708     }
709 
710     ModifiedStructMachinery machinery;
711     if (!state.recurse(*substructure, machinery, state.getIsUBO()))
712     {
713         return false;
714     }
715 
716     TType &newType = *new TType(machinery.modifiedStruct, false);
717     if (type.isArray())
718     {
719         newType.makeArrays(type.getArraySizes());
720     }
721 
722     TIntermFunctionDefinition *converter = machinery.getConverter(state.config.convertType);
723     ASSERT(converter);
724 
725     state.addModifiedField(field, newType, storage, packing, state.symbolEnv.isPointer(field));
726     if (state.symbolEnv.isPointer(field))
727     {
728         state.symbolEnv.removePointer(field);
729     }
730     state.addConversion(*converter->getFunction());
731 
732     return true;
733 }
734 
SplitMatrixColumns(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)735 bool SplitMatrixColumns(ConvertStructState &state,
736                         const TField &field,
737                         const TLayoutBlockStorage storage,
738                         const TLayoutMatrixPacking packing)
739 {
740     const TType &type = *field.type();
741     if (!type.isMatrix())
742     {
743         return false;
744     }
745 
746     if (!state.config.splitMatrixColumns(field))
747     {
748         return false;
749     }
750 
751     const uint8_t cols = type.getCols();
752     TType &rowType     = DropColumns(type);
753 
754     for (uint8_t c = 0; c < cols; ++c)
755     {
756         state.pushPath(c);
757 
758         state.addModifiedField(field, rowType, storage, packing, state.symbolEnv.isPointer(field));
759         if (state.symbolEnv.isPointer(field))
760         {
761             state.symbolEnv.removePointer(field);
762         }
763         state.addConversion([=](Access::Env &, OriginalAccess &o, ModifiedAccess &m) {
764             return Access{o, m};
765         });
766 
767         state.popPath();
768     }
769 
770     return true;
771 }
772 
SaturateMatrixRows(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)773 bool SaturateMatrixRows(ConvertStructState &state,
774                         const TField &field,
775                         const TLayoutBlockStorage storage,
776                         const TLayoutMatrixPacking packing)
777 {
778     const TType &type = *field.type();
779     if (!type.isMatrix())
780     {
781         return false;
782     }
783     const bool isRowMajor    = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
784     const uint8_t rows       = type.getRows();
785     const uint8_t saturation = state.config.saturateMatrixRows(field);
786     if (saturation <= rows)
787     {
788         return false;
789     }
790 
791     const uint8_t cols = type.getCols();
792     TType &satType     = SetMatrixRowDim(type, saturation);
793     state.addModifiedField(field, satType, storage, packing, state.symbolEnv.isPointer(field));
794     if (state.symbolEnv.isPointer(field))
795     {
796         state.symbolEnv.removePointer(field);
797     }
798 
799     for (uint8_t c = 0; c < cols; ++c)
800     {
801         for (uint8_t r = 0; r < rows; ++r)
802         {
803             state.addConversion([=](Access::Env &, OriginalAccess &o, ModifiedAccess &m) {
804                 uint8_t firstModifiedIndex  = isRowMajor ? r : c;
805                 uint8_t secondModifiedIndex = isRowMajor ? c : r;
806                 auto &o_                    = AccessIndex(AccessIndex(o, c), r);
807                 auto &m_ = AccessIndex(AccessIndex(m, firstModifiedIndex), secondModifiedIndex);
808                 return Access{o_, m_};
809             });
810         }
811     }
812 
813     return true;
814 }
815 
TestBoolToUint(ConvertStructState & state,const TField & field)816 bool TestBoolToUint(ConvertStructState &state, const TField &field)
817 {
818     if (field.type()->getBasicType() != TBasicType::EbtBool)
819     {
820         return false;
821     }
822     if (!state.config.promoteBoolToUint(field))
823     {
824         return false;
825     }
826     return true;
827 }
828 
ConvertBoolToUint(ConvertType convertType,OriginalAccess & o,ModifiedAccess & m)829 Access ConvertBoolToUint(ConvertType convertType, OriginalAccess &o, ModifiedAccess &m)
830 {
831     auto coerce = [](TIntermTyped &to, TIntermTyped &from) -> TIntermTyped & {
832         return *TIntermAggregate::CreateConstructor(to.getType(), new TIntermSequence{&from});
833     };
834     switch (convertType)
835     {
836         case ConvertType::OriginalToModified:
837             return Access{coerce(m, o), m};
838         case ConvertType::ModifiedToOriginal:
839             return Access{o, coerce(o, m)};
840     }
841 }
842 
SaturateScalarOrVectorCommon(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing,const bool array)843 bool SaturateScalarOrVectorCommon(ConvertStructState &state,
844                                   const TField &field,
845                                   const TLayoutBlockStorage storage,
846                                   const TLayoutMatrixPacking packing,
847                                   const bool array)
848 {
849     const TType &type = *field.type();
850     if (type.isArray() != array)
851     {
852         return false;
853     }
854     if (!((type.isRank0() && HasScalarBasicType(type)) || type.isVector()))
855     {
856         return false;
857     }
858     const auto saturator =
859         array ? state.config.saturateScalarOrVectorArrays : state.config.saturateScalarOrVector;
860     const uint8_t dim        = type.getNominalSize();
861     const uint8_t saturation = saturator(field);
862     if (saturation <= dim)
863     {
864         return false;
865     }
866 
867     TType &satType        = SetVectorDim(type, saturation);
868     const bool boolToUint = TestBoolToUint(state, field);
869     if (boolToUint)
870     {
871         satType.setBasicType(TBasicType::EbtUInt);
872     }
873     state.addModifiedField(field, satType, storage, packing, state.symbolEnv.isPointer(field));
874     if (state.symbolEnv.isPointer(field))
875     {
876         state.symbolEnv.removePointer(field);
877     }
878 
879     for (uint8_t d = 0; d < dim; ++d)
880     {
881         state.addConversion([=](Access::Env &env, OriginalAccess &o, ModifiedAccess &m) {
882             auto &o_ = dim > 1 ? AccessIndex(o, d) : o;
883             auto &m_ = AccessIndex(m, d);
884             if (boolToUint)
885             {
886                 return ConvertBoolToUint(env.type, o_, m_);
887             }
888             else
889             {
890                 return Access{o_, m_};
891             }
892         });
893     }
894 
895     return true;
896 }
897 
SaturateScalarOrVectorArrays(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)898 bool SaturateScalarOrVectorArrays(ConvertStructState &state,
899                                   const TField &field,
900                                   const TLayoutBlockStorage storage,
901                                   const TLayoutMatrixPacking packing)
902 {
903     return SaturateScalarOrVectorCommon(state, field, storage, packing, true);
904 }
905 
SaturateScalarOrVector(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)906 bool SaturateScalarOrVector(ConvertStructState &state,
907                             const TField &field,
908                             const TLayoutBlockStorage storage,
909                             const TLayoutMatrixPacking packing)
910 {
911     return SaturateScalarOrVectorCommon(state, field, storage, packing, false);
912 }
913 
PromoteBoolToUint(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)914 bool PromoteBoolToUint(ConvertStructState &state,
915                        const TField &field,
916                        const TLayoutBlockStorage storage,
917                        const TLayoutMatrixPacking packing)
918 {
919     if (!TestBoolToUint(state, field))
920     {
921         return false;
922     }
923 
924     auto &promotedType = CloneType(*field.type());
925     promotedType.setBasicType(TBasicType::EbtUInt);
926     state.addModifiedField(field, promotedType, storage, packing, state.symbolEnv.isPointer(field));
927     if (state.symbolEnv.isPointer(field))
928     {
929         state.symbolEnv.removePointer(field);
930     }
931 
932     state.addConversion([=](Access::Env &env, OriginalAccess &o, ModifiedAccess &m) {
933         return ConvertBoolToUint(env.type, o, m);
934     });
935 
936     return true;
937 }
938 
ModifyCommon(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)939 bool ModifyCommon(ConvertStructState &state,
940                   const TField &field,
941                   const TLayoutBlockStorage storage,
942                   const TLayoutMatrixPacking packing)
943 {
944     ModifyFunc funcs[] = {
945         InlineStruct,                  //
946         RecurseStruct,                 //
947         SplitMatrixColumns,            //
948         SaturateMatrixRows,            //
949         SaturateScalarOrVectorArrays,  //
950         SaturateScalarOrVector,        //
951         PromoteBoolToUint,             //
952     };
953 
954     for (ModifyFunc func : funcs)
955     {
956         if (func(state, field, storage, packing))
957         {
958             return true;
959         }
960     }
961 
962     return IdentityModify(state, field, storage, packing);
963 }
964 
InlineArray(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)965 bool InlineArray(ConvertStructState &state,
966                  const TField &field,
967                  const TLayoutBlockStorage storage,
968                  const TLayoutMatrixPacking packing)
969 {
970     const TType &type = *field.type();
971     if (!type.isArray())
972     {
973         return false;
974     }
975     if (!state.config.inlineArray(field))
976     {
977         return false;
978     }
979 
980     const unsigned volume = type.getArraySizeProduct();
981     const bool isMultiDim = type.isArrayOfArrays();
982 
983     auto &innermostType = InnermostType(type);
984 
985     if (isMultiDim)
986     {
987         state.pushPath(FlattenArray());
988     }
989 
990     for (unsigned i = 0; i < volume; ++i)
991     {
992         state.pushPath(i);
993         TType setType(innermostType);
994         if (setType.getLayoutQualifier().locationsSpecified)
995         {
996             TLayoutQualifier qualifier(innermostType.getLayoutQualifier());
997             qualifier.location           = innermostType.getLayoutQualifier().location + i;
998             qualifier.locationsSpecified = 1;
999             setType.setLayoutQualifier(qualifier);
1000         }
1001         const TField innermostField(&setType, field.name(), field.line(), field.symbolType());
1002         ModifyCommon(state, innermostField, storage, packing);
1003         state.popPath();
1004     }
1005 
1006     if (isMultiDim)
1007     {
1008         state.popPath();
1009     }
1010 
1011     return true;
1012 }
1013 
ModifyRecursive(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)1014 bool ModifyRecursive(ConvertStructState &state,
1015                      const TField &field,
1016                      const TLayoutBlockStorage storage,
1017                      const TLayoutMatrixPacking packing)
1018 {
1019     state.pushPath(field);
1020 
1021     bool modified;
1022     if (InlineArray(state, field, storage, packing))
1023     {
1024         modified = true;
1025     }
1026     else
1027     {
1028         modified = ModifyCommon(state, field, storage, packing);
1029     }
1030 
1031     state.popPath();
1032 
1033     return modified;
1034 }
1035 
1036 }  // anonymous namespace
1037 
1038 ////////////////////////////////////////////////////////////////////////////////
1039 
TryCreateModifiedStruct(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,const ModifyStructConfig & config,const TStructure & originalStruct,const Name & modifiedStructName,ModifiedStructMachineries & outMachineries,const bool isUBO,const bool allowPadding)1040 bool sh::TryCreateModifiedStruct(TCompiler &compiler,
1041                                  SymbolEnv &symbolEnv,
1042                                  IdGen &idGen,
1043                                  const ModifyStructConfig &config,
1044                                  const TStructure &originalStruct,
1045                                  const Name &modifiedStructName,
1046                                  ModifiedStructMachineries &outMachineries,
1047                                  const bool isUBO,
1048                                  const bool allowPadding)
1049 {
1050     ConvertStructState state(compiler, symbolEnv, idGen, config, outMachineries, isUBO);
1051     size_t identicalFieldCount = 0;
1052 
1053     const TFieldList &originalFields = originalStruct.fields();
1054     for (TField *originalField : originalFields)
1055     {
1056         const TType &originalType          = *originalField->type();
1057         const TLayoutBlockStorage storage  = Overlay(config.initialBlockStorage, originalType);
1058         const TLayoutMatrixPacking packing = Overlay(config.initialMatrixPacking, originalType);
1059         if (!ModifyRecursive(state, *originalField, storage, packing))
1060         {
1061             ++identicalFieldCount;
1062         }
1063     }
1064 
1065     state.finalize(allowPadding);
1066 
1067     if (identicalFieldCount == originalFields.size() && !state.hasPacking() && !state.hasPadding())
1068     {
1069         return false;
1070     }
1071     state.publish(originalStruct, modifiedStructName);
1072 
1073     return true;
1074 }
1075