• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2020 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 
#include <algorithm>
#include <cstring>
#include <numeric>
#include <string>
#include <unordered_map>
#include <unordered_set>
12 
13 #include "compiler/translator/Compiler.h"
14 #include "compiler/translator/TranslatorMetalDirect.h"
15 #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h"
16 #include "compiler/translator/TranslatorMetalDirect/ModifyStruct.h"
17 
18 using namespace sh;
19 
20 ////////////////////////////////////////////////////////////////////////////////
21 
size() const22 size_t ModifiedStructMachineries::size() const
23 {
24     return ordering.size();
25 }
26 
at(size_t index) const27 const ModifiedStructMachinery &ModifiedStructMachineries::at(size_t index) const
28 {
29     ASSERT(index < size());
30     const TStructure *s              = ordering[index];
31     const ModifiedStructMachinery *m = find(*s);
32     ASSERT(m);
33     return *m;
34 }
35 
find(const TStructure & s) const36 const ModifiedStructMachinery *ModifiedStructMachineries::find(const TStructure &s) const
37 {
38     auto iter = originalToMachinery.find(&s);
39     if (iter == originalToMachinery.end())
40     {
41         return nullptr;
42     }
43     return &iter->second;
44 }
45 
insert(const TStructure & s,const ModifiedStructMachinery & machinery)46 void ModifiedStructMachineries::insert(const TStructure &s,
47                                        const ModifiedStructMachinery &machinery)
48 {
49     ASSERT(!find(s));
50     originalToMachinery[&s] = machinery;
51     ordering.push_back(&s);
52 }
53 
54 ////////////////////////////////////////////////////////////////////////////////
55 
56 namespace
57 {
58 
Flatten(SymbolEnv & symbolEnv,TIntermTyped & node)59 TIntermTyped &Flatten(SymbolEnv &symbolEnv, TIntermTyped &node)
60 {
61     auto &type = node.getType();
62     ASSERT(type.isArray());
63 
64     auto &retType = InnermostType(type);
65     retType.makeArray(1);
66 
67     return symbolEnv.callFunctionOverload(Name("flatten"), retType, *new TIntermSequence{&node});
68 }
69 
70 struct FlattenArray
71 {};
72 
73 struct PathItem
74 {
75     enum class Type
76     {
77         Field,         // Struct field indexing.
78         Index,         // Array, vector, or matrix indexing.
79         FlattenArray,  // Array of any rank -> pointer of innermost type.
80     };
81 
PathItem__anon1fceafbe0111::PathItem82     PathItem(const TField &field) : field(&field), type(Type::Field) {}
PathItem__anon1fceafbe0111::PathItem83     PathItem(int index) : index(index), type(Type::Index) {}
PathItem__anon1fceafbe0111::PathItem84     PathItem(unsigned index) : PathItem(static_cast<int>(index)) {}
PathItem__anon1fceafbe0111::PathItem85     PathItem(FlattenArray flatten) : type(Type::FlattenArray) {}
86 
87     union
88     {
89         const TField *field;
90         int index;
91     };
92     Type type;
93 };
94 
BuildPathAccess(SymbolEnv & symbolEnv,const TVariable & var,const std::vector<PathItem> & path)95 TIntermTyped &BuildPathAccess(SymbolEnv &symbolEnv,
96                               const TVariable &var,
97                               const std::vector<PathItem> &path)
98 {
99     TIntermTyped *curr = new TIntermSymbol(&var);
100     for (const PathItem &item : path)
101     {
102         switch (item.type)
103         {
104             case PathItem::Type::Field:
105                 curr = &AccessField(*curr, item.field->name());
106                 break;
107             case PathItem::Type::Index:
108                 curr = &AccessIndex(*curr, item.index);
109                 break;
110             case PathItem::Type::FlattenArray:
111             {
112                 curr = &Flatten(symbolEnv, *curr);
113             }
114             break;
115         }
116     }
117     return *curr;
118 }
119 
120 ////////////////////////////////////////////////////////////////////////////////
121 
122 using OriginalParam = const TVariable &;
123 using ModifiedParam = const TVariable &;
124 
125 using OriginalAccess = TIntermTyped;
126 using ModifiedAccess = TIntermTyped;
127 
128 struct Access
129 {
130     OriginalAccess &original;
131     ModifiedAccess &modified;
132 
133     struct Env
134     {
135         const ConvertType type;
136     };
137 };
138 
139 using ConversionFunc = std::function<Access(Access::Env &, OriginalAccess &, ModifiedAccess &)>;
140 
141 class ConvertStructState : angle::NonCopyable
142 {
143   private:
144     struct ConversionInfo
145     {
146         ConversionFunc stdFunc;
147         const TFunction *astFunc;
148         std::vector<PathItem> pathItems;
149         ImmutableString pathName;
150     };
151 
152   public:
ConvertStructState(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,const ModifyStructConfig & config,ModifiedStructMachineries & outMachineries,const bool isUBO)153     ConvertStructState(TCompiler &compiler,
154                        SymbolEnv &symbolEnv,
155                        IdGen &idGen,
156                        const ModifyStructConfig &config,
157                        ModifiedStructMachineries &outMachineries,
158                        const bool isUBO)
159         : mCompiler(compiler),
160           config(config),
161           symbolEnv(symbolEnv),
162           modifiedFields(*new TFieldList()),
163           symbolTable(symbolEnv.symbolTable()),
164           idGen(idGen),
165           outMachineries(outMachineries),
166           isUBO(isUBO)
167     {}
168 
~ConvertStructState()169     ~ConvertStructState()
170     {
171         ASSERT(namePath.empty());
172         ASSERT(namePathSizes.empty());
173     }
174 
publish(const TStructure & originalStruct,const Name & modifiedStructName)175     void publish(const TStructure &originalStruct, const Name &modifiedStructName)
176     {
177         const bool isOriginalToModified = config.convertType == ConvertType::OriginalToModified;
178 
179         auto &modifiedStruct = *new TStructure(&symbolTable, modifiedStructName.rawName(),
180                                                &modifiedFields, modifiedStructName.symbolType());
181 
182         auto &func = *new TFunction(
183             &symbolTable,
184             idGen.createNewName(isOriginalToModified ? "originalToModified" : "modifiedToOriginal")
185                 .rawName(),
186             SymbolType::AngleInternal, new TType(TBasicType::EbtVoid), false);
187 
188         OriginalParam originalParam =
189             CreateInstanceVariable(symbolTable, originalStruct, Name("original"));
190         ModifiedParam modifiedParam =
191             CreateInstanceVariable(symbolTable, modifiedStruct, Name("modified"));
192 
193         symbolEnv.markAsReference(originalParam, AddressSpace::Thread);
194         symbolEnv.markAsReference(modifiedParam, config.externalAddressSpace);
195         if (isOriginalToModified)
196         {
197             func.addParameter(&originalParam);
198             func.addParameter(&modifiedParam);
199         }
200         else
201         {
202             func.addParameter(&modifiedParam);
203             func.addParameter(&originalParam);
204         }
205 
206         TIntermBlock &body = *new TIntermBlock();
207 
208         Access::Env env{config.convertType};
209 
210         for (ConversionInfo &info : conversionInfos)
211         {
212             auto convert = [&](OriginalAccess &original, ModifiedAccess &modified) {
213                 if (info.astFunc)
214                 {
215                     ASSERT(!info.stdFunc);
216                     TIntermTyped &src  = isOriginalToModified ? modified : original;
217                     TIntermTyped &dest = isOriginalToModified ? original : modified;
218                     body.appendStatement(TIntermAggregate::CreateFunctionCall(
219                         *info.astFunc, new TIntermSequence{&dest, &src}));
220                 }
221                 else
222                 {
223                     ASSERT(info.stdFunc);
224                     Access access      = info.stdFunc(env, original, modified);
225                     TIntermTyped &src  = isOriginalToModified ? access.original : access.modified;
226                     TIntermTyped &dest = isOriginalToModified ? access.modified : access.original;
227                     body.appendStatement(new TIntermBinary(TOperator::EOpAssign, &dest, &src));
228                 }
229             };
230 
231             OriginalAccess *original = &BuildPathAccess(symbolEnv, originalParam, info.pathItems);
232             ModifiedAccess *modified = &AccessField(modifiedParam, info.pathName);
233 
234             const TType ot = original->getType();
235             const TType mt = modified->getType();
236             ASSERT(ot.isArray() == mt.isArray());
237 
238             if (ot.isArray() && (ot.getLayoutQualifier().matrixPacking == EmpRowMajor || ot != mt))
239             {
240                 ASSERT(ot.getArraySizes() == mt.getArraySizes());
241                 if (ot.isArrayOfArrays())
242                 {
243                     original = &Flatten(symbolEnv, *original);
244                     modified = &Flatten(symbolEnv, *modified);
245                 }
246                 const int volume = static_cast<int>(ot.getArraySizeProduct());
247                 for (int i = 0; i < volume; ++i)
248                 {
249                     if (i != 0)
250                     {
251                         original = original->deepCopy();
252                         modified = modified->deepCopy();
253                     }
254                     OriginalAccess &o = AccessIndex(*original, i);
255                     OriginalAccess &m = AccessIndex(*modified, i);
256                     convert(o, m);
257                 }
258             }
259             else
260             {
261                 convert(*original, *modified);
262             }
263         }
264 
265         auto *funcProto = new TIntermFunctionPrototype(&func);
266         auto *funcDef   = new TIntermFunctionDefinition(funcProto, &body);
267 
268         ModifiedStructMachinery machinery;
269         machinery.modifiedStruct                   = &modifiedStruct;
270         machinery.getConverter(config.convertType) = funcDef;
271 
272         outMachineries.insert(originalStruct, machinery);
273     }
274 
pushPath(PathItem const & item)275     void pushPath(PathItem const &item)
276     {
277         pathItems.push_back(item);
278 
279         switch (item.type)
280         {
281             case PathItem::Type::Field:
282                 pushNamePath(item.field->name().data());
283                 break;
284 
285             case PathItem::Type::Index:
286                 pushNamePath(item.index);
287                 break;
288 
289             case PathItem::Type::FlattenArray:
290                 namePathSizes.push_back(namePath.size());
291                 break;
292         }
293     }
294 
popPath()295     void popPath()
296     {
297         ASSERT(!namePath.empty());
298         ASSERT(!namePathSizes.empty());
299         namePath.resize(namePathSizes.back());
300         namePathSizes.pop_back();
301 
302         ASSERT(!pathItems.empty());
303         pathItems.pop_back();
304     }
305 
finalize(const bool allowPadding)306     void finalize(const bool allowPadding)
307     {
308         ASSERT(!finalized);
309         finalized = true;
310         introducePacking();
311         ASSERT(metalLayoutTotal == Layout::Identity());
312         // Only pad substructs. We don't want to pad the structure that contains all the UBOs, only
313         // individual UBOs.
314         if (allowPadding)
315             introducePadding();
316     }
317 
addModifiedField(const TField & field,TType & newType,TLayoutBlockStorage storage,TLayoutMatrixPacking packing,const AddressSpace * addressSpace)318     void addModifiedField(const TField &field,
319                           TType &newType,
320                           TLayoutBlockStorage storage,
321                           TLayoutMatrixPacking packing,
322                           const AddressSpace *addressSpace)
323     {
324         TLayoutQualifier layoutQualifier = newType.getLayoutQualifier();
325         layoutQualifier.blockStorage     = storage;
326         layoutQualifier.matrixPacking    = packing;
327         newType.setLayoutQualifier(layoutQualifier);
328 
329         const ImmutableString pathName(namePath);
330         TField *modifiedField = new TField(&newType, pathName, field.line(), field.symbolType());
331         if (addressSpace)
332         {
333             symbolEnv.markAsPointer(*modifiedField, *addressSpace);
334         }
335         if (symbolEnv.isUBO(field))
336         {
337             symbolEnv.markAsUBO(*modifiedField);
338         }
339         modifiedFields.push_back(modifiedField);
340     }
341 
addConversion(const ConversionFunc & func)342     void addConversion(const ConversionFunc &func)
343     {
344         ASSERT(!modifiedFields.empty());
345         conversionInfos.push_back({func, nullptr, pathItems, modifiedFields.back()->name()});
346     }
347 
addConversion(const TFunction & func)348     void addConversion(const TFunction &func)
349     {
350         ASSERT(!modifiedFields.empty());
351         conversionInfos.push_back({{}, &func, pathItems, modifiedFields.back()->name()});
352     }
353 
hasPacking() const354     bool hasPacking() const { return containsPacked; }
355 
hasPadding() const356     bool hasPadding() const { return padFieldCount > 0; }
357 
recurse(const TStructure & structure,ModifiedStructMachinery & outMachinery,const bool isUBORecurse)358     bool recurse(const TStructure &structure,
359                  ModifiedStructMachinery &outMachinery,
360                  const bool isUBORecurse)
361     {
362         const ModifiedStructMachinery *m = outMachineries.find(structure);
363         if (m == nullptr)
364         {
365             TranslatorMetalReflection *reflection = mtl::getTranslatorMetalReflection(&mCompiler);
366             reflection->addOriginalName(structure.uniqueId().get(), structure.name().data());
367             const Name name = idGen.createNewName(structure.name().data());
368             if (!TryCreateModifiedStruct(mCompiler, symbolEnv, idGen, config, structure, name,
369                                          outMachineries, isUBORecurse, true))
370             {
371                 return false;
372             }
373             m = outMachineries.find(structure);
374             ASSERT(m);
375         }
376         outMachinery = *m;
377         return true;
378     }
379 
getIsUBO() const380     bool getIsUBO() const { return isUBO; }
381 
382   private:
addPadding(size_t padAmount,bool updateLayout)383     void addPadding(size_t padAmount, bool updateLayout)
384     {
385         if (padAmount == 0)
386         {
387             return;
388         }
389 
390         const size_t begin = modifiedFields.size();
391 
392         // Iteratively adding in scalar or vector padding because some struct types will not
393         // allow matrix or array members.
394         while (padAmount > 0)
395         {
396             TType *padType;
397             if (padAmount >= 16)
398             {
399                 padAmount -= 16;
400                 padType = new TType(TBasicType::EbtFloat, 4);
401             }
402             else if (padAmount >= 8)
403             {
404                 padAmount -= 8;
405                 padType = new TType(TBasicType::EbtFloat, 2);
406             }
407             else if (padAmount >= 4)
408             {
409                 padAmount -= 4;
410                 padType = new TType(TBasicType::EbtFloat);
411             }
412             else if (padAmount >= 2)
413             {
414                 padAmount -= 2;
415                 padType = new TType(TBasicType::EbtBool, 2);
416             }
417             else
418             {
419                 ASSERT(padAmount == 1);
420                 padAmount -= 1;
421                 padType = new TType(TBasicType::EbtBool);
422             }
423 
424             if (padType->getBasicType() != EbtBool)
425             {
426                 padType->setPrecision(EbpLow);
427             }
428 
429             if (updateLayout)
430             {
431                 metalLayoutTotal += MetalLayoutOf(*padType);
432             }
433 
434             const Name name = idGen.createNewName("pad");
435             modifiedFields.push_back(
436                 new TField(padType, name.rawName(), kNoSourceLoc, name.symbolType()));
437             ++padFieldCount;
438         }
439 
440         std::reverse(modifiedFields.begin() + begin, modifiedFields.end());
441     }
442 
introducePacking()443     void introducePacking()
444     {
445         if (!config.allowPacking)
446         {
447             return;
448         }
449 
450         auto setUnpackedStorage = [](TType &type) {
451             TLayoutBlockStorage storage = type.getLayoutQualifier().blockStorage;
452             switch (storage)
453             {
454                 case TLayoutBlockStorage::EbsShared:
455                     storage = TLayoutBlockStorage::EbsStd140;
456                     break;
457                 case TLayoutBlockStorage::EbsPacked:
458                     storage = TLayoutBlockStorage::EbsStd430;
459                     break;
460                 case TLayoutBlockStorage::EbsStd140:
461                 case TLayoutBlockStorage::EbsStd430:
462                 case TLayoutBlockStorage::EbsUnspecified:
463                     break;
464             }
465             SetBlockStorage(type, storage);
466         };
467 
468         Layout glslLayoutTotal = Layout::Identity();
469         const size_t size      = modifiedFields.size();
470 
471         for (size_t i = 0; i < size; ++i)
472         {
473             TField &curr           = *modifiedFields[i];
474             TType &currType        = *curr.type();
475             const bool canBePacked = CanBePacked(currType);
476 
477             auto dontPack = [&]() {
478                 if (canBePacked)
479                 {
480                     setUnpackedStorage(currType);
481                 }
482                 glslLayoutTotal += GlslLayoutOf(currType);
483             };
484 
485             if (!CanBePacked(currType))
486             {
487                 dontPack();
488                 continue;
489             }
490 
491             const Layout packedGlslLayout           = GlslLayoutOf(currType);
492             const TLayoutBlockStorage packedStorage = currType.getLayoutQualifier().blockStorage;
493             setUnpackedStorage(currType);
494             const Layout unpackedGlslLayout = GlslLayoutOf(currType);
495             SetBlockStorage(currType, packedStorage);
496 
497             ASSERT(packedGlslLayout.sizeOf <= unpackedGlslLayout.sizeOf);
498             if (packedGlslLayout.sizeOf == unpackedGlslLayout.sizeOf)
499             {
500                 dontPack();
501                 continue;
502             }
503 
504             const size_t j = i + 1;
505             if (j == size)
506             {
507                 dontPack();
508                 break;
509             }
510 
511             const size_t pad            = unpackedGlslLayout.sizeOf - packedGlslLayout.sizeOf;
512             const TField &next          = *modifiedFields[j];
513             const Layout nextGlslLayout = GlslLayoutOf(*next.type());
514 
515             if (pad < nextGlslLayout.sizeOf)
516             {
517                 dontPack();
518                 continue;
519             }
520 
521             symbolEnv.markAsPacked(curr);
522             glslLayoutTotal += packedGlslLayout;
523             containsPacked = true;
524         }
525     }
526 
introducePadding()527     void introducePadding()
528     {
529         if (!config.allowPadding)
530         {
531             return;
532         }
533 
534         MetalLayoutOfConfig layoutConfig;
535         layoutConfig.disablePacking             = !config.allowPacking;
536         layoutConfig.assumeStructsAreTailPadded = true;
537 
538         TFieldList fields = std::move(modifiedFields);
539         ASSERT(!fields.empty());  // GLSL requires at least one member.
540 
541         const TField *const first = fields.front();
542 
543         for (TField *field : fields)
544         {
545             const TType &type = *field->type();
546 
547             const Layout glslLayout  = GlslLayoutOf(type);
548             const Layout metalLayout = MetalLayoutOf(type, layoutConfig);
549 
550             size_t prePadAmount = 0;
551             if (glslLayout.alignOf > metalLayout.alignOf && field != first)
552             {
553                 const size_t prePaddedSize = metalLayoutTotal.sizeOf;
554                 metalLayoutTotal.requireAlignment(glslLayout.alignOf, true);
555                 const size_t paddedSize = metalLayoutTotal.sizeOf;
556                 prePadAmount            = paddedSize - prePaddedSize;
557                 metalLayoutTotal += metalLayout;
558                 addPadding(prePadAmount, false);  // Note: requireAlignment() already updated layout
559             }
560             else
561             {
562                 metalLayoutTotal += metalLayout;
563             }
564 
565             modifiedFields.push_back(field);
566 
567             if (glslLayout.sizeOf > metalLayout.sizeOf && field != fields.back())
568             {
569                 const bool updateLayout = true;  // XXX: Correct?
570                 const size_t padAmount  = glslLayout.sizeOf - metalLayout.sizeOf;
571                 addPadding(padAmount, updateLayout);
572             }
573         }
574     }
575 
pushNamePath(const char * extra)576     void pushNamePath(const char *extra)
577     {
578         ASSERT(extra && *extra != '\0');
579         namePathSizes.push_back(namePath.size());
580         const char *p = extra;
581         if (namePath.empty())
582         {
583             namePath = p;
584             return;
585         }
586         while (*p == '_')
587         {
588             ++p;
589         }
590         if (*p == '\0')
591         {
592             p = "x";
593         }
594         if (namePath.back() != '_')
595         {
596             namePath += '_';
597         }
598         namePath += p;
599     }
600 
pushNamePath(unsigned extra)601     void pushNamePath(unsigned extra)
602     {
603         char buffer[std::numeric_limits<unsigned>::digits10 + 1];
604         sprintf(buffer, "%u", extra);
605         pushNamePath(buffer);
606     }
607 
608   public:
609     TCompiler &mCompiler;
610     const ModifyStructConfig &config;
611     SymbolEnv &symbolEnv;
612 
613   private:
614     TFieldList &modifiedFields;
615     Layout metalLayoutTotal = Layout::Identity();
616     size_t padFieldCount    = 0;
617     bool containsPacked     = false;
618     bool finalized          = false;
619 
620     std::vector<PathItem> pathItems;
621 
622     std::vector<size_t> namePathSizes;
623     std::string namePath;
624 
625     std::vector<ConversionInfo> conversionInfos;
626     TSymbolTable &symbolTable;
627     IdGen &idGen;
628     ModifiedStructMachineries &outMachineries;
629     const bool isUBO;
630 };
631 
632 ////////////////////////////////////////////////////////////////////////////////
633 
634 using ModifyFunc = bool (*)(ConvertStructState &state,
635                             const TField &field,
636                             const TLayoutBlockStorage storage,
637                             const TLayoutMatrixPacking packing);
638 
639 bool ModifyRecursive(ConvertStructState &state,
640                      const TField &field,
641                      const TLayoutBlockStorage storage,
642                      const TLayoutMatrixPacking packing);
643 
IdentityModify(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)644 bool IdentityModify(ConvertStructState &state,
645                     const TField &field,
646                     const TLayoutBlockStorage storage,
647                     const TLayoutMatrixPacking packing)
648 {
649     const TType &type = *field.type();
650     state.addModifiedField(field, CloneType(type), storage, packing, nullptr);
651     state.addConversion([=](Access::Env &, OriginalAccess &o, ModifiedAccess &m) {
652         return Access{o, m};
653     });
654     return false;
655 }
656 
InlineStruct(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)657 bool InlineStruct(ConvertStructState &state,
658                   const TField &field,
659                   const TLayoutBlockStorage storage,
660                   const TLayoutMatrixPacking packing)
661 {
662     const TType &type              = *field.type();
663     const TStructure *substructure = state.symbolEnv.remap(type.getStruct());
664     if (!substructure)
665     {
666         return false;
667     }
668     if (type.isArray())
669     {
670         return false;
671     }
672     if (!state.config.inlineStruct(field))
673     {
674         return false;
675     }
676 
677     const TFieldList &subfields = substructure->fields();
678     for (const TField *subfield : subfields)
679     {
680         const TType &subtype                  = *subfield->type();
681         const TLayoutBlockStorage substorage  = Overlay(storage, subtype);
682         const TLayoutMatrixPacking subpacking = Overlay(packing, subtype);
683         ModifyRecursive(state, *subfield, substorage, subpacking);
684     }
685 
686     return true;
687 }
688 
RecurseStruct(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)689 bool RecurseStruct(ConvertStructState &state,
690                    const TField &field,
691                    const TLayoutBlockStorage storage,
692                    const TLayoutMatrixPacking packing)
693 {
694     const TType &type              = *field.type();
695     const TStructure *substructure = state.symbolEnv.remap(type.getStruct());
696     if (!substructure)
697     {
698         return false;
699     }
700     if (!state.config.recurseStruct(field))
701     {
702         return false;
703     }
704 
705     ModifiedStructMachinery machinery;
706     if (!state.recurse(*substructure, machinery, state.getIsUBO()))
707     {
708         return false;
709     }
710 
711     TType &newType = *new TType(machinery.modifiedStruct, false);
712     if (type.isArray())
713     {
714         newType.makeArrays(type.getArraySizes());
715     }
716 
717     TIntermFunctionDefinition *converter = machinery.getConverter(state.config.convertType);
718     ASSERT(converter);
719 
720     state.addModifiedField(field, newType, storage, packing, state.symbolEnv.isPointer(field));
721     if (state.symbolEnv.isPointer(field))
722     {
723         state.symbolEnv.removePointer(field);
724     }
725     state.addConversion(*converter->getFunction());
726 
727     return true;
728 }
729 
SplitMatrixColumns(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)730 bool SplitMatrixColumns(ConvertStructState &state,
731                         const TField &field,
732                         const TLayoutBlockStorage storage,
733                         const TLayoutMatrixPacking packing)
734 {
735     const TType &type = *field.type();
736     if (!type.isMatrix())
737     {
738         return false;
739     }
740 
741     if (!state.config.splitMatrixColumns(field))
742     {
743         return false;
744     }
745 
746     const int cols = type.getCols();
747     TType &rowType = DropColumns(type);
748 
749     for (int c = 0; c < cols; ++c)
750     {
751         state.pushPath(c);
752 
753         state.addModifiedField(field, rowType, storage, packing, state.symbolEnv.isPointer(field));
754         if (state.symbolEnv.isPointer(field))
755         {
756             state.symbolEnv.removePointer(field);
757         }
758         state.addConversion([=](Access::Env &, OriginalAccess &o, ModifiedAccess &m) {
759             return Access{o, m};
760         });
761 
762         state.popPath();
763     }
764 
765     return true;
766 }
767 
SaturateMatrixRows(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)768 bool SaturateMatrixRows(ConvertStructState &state,
769                         const TField &field,
770                         const TLayoutBlockStorage storage,
771                         const TLayoutMatrixPacking packing)
772 {
773     const TType &type = *field.type();
774     if (!type.isMatrix())
775     {
776         return false;
777     }
778     const bool isRowMajor = type.getLayoutQualifier().matrixPacking == EmpRowMajor;
779     const int rows        = type.getRows();
780     const int saturation  = state.config.saturateMatrixRows(field);
781     if (saturation <= rows && !isRowMajor)
782     {
783         return false;
784     }
785 
786     const int cols = type.getCols();
787     TType &satType = SetMatrixRowDim(type, saturation);
788     state.addModifiedField(field, satType, storage, packing, state.symbolEnv.isPointer(field));
789     if (state.symbolEnv.isPointer(field))
790     {
791         state.symbolEnv.removePointer(field);
792     }
793 
794     for (int c = 0; c < cols; ++c)
795     {
796         for (int r = 0; r < rows; ++r)
797         {
798             state.addConversion([=](Access::Env &, OriginalAccess &o, ModifiedAccess &m) {
799                 int firstModifiedIndex  = isRowMajor ? r : c;
800                 int secondModifiedIndex = isRowMajor ? c : r;
801                 auto &o_                = AccessIndex(AccessIndex(o, c), r);
802                 auto &m_ = AccessIndex(AccessIndex(m, firstModifiedIndex), secondModifiedIndex);
803                 return Access{o_, m_};
804             });
805         }
806     }
807 
808     return true;
809 }
810 
TestBoolToUint(ConvertStructState & state,const TField & field)811 bool TestBoolToUint(ConvertStructState &state, const TField &field)
812 {
813     if (field.type()->getBasicType() != TBasicType::EbtBool)
814     {
815         return false;
816     }
817     if (!state.config.promoteBoolToUint(field))
818     {
819         return false;
820     }
821     return true;
822 }
823 
ConvertBoolToUint(ConvertType convertType,OriginalAccess & o,ModifiedAccess & m)824 Access ConvertBoolToUint(ConvertType convertType, OriginalAccess &o, ModifiedAccess &m)
825 {
826     auto coerce = [](TIntermTyped &to, TIntermTyped &from) -> TIntermTyped & {
827         return *TIntermAggregate::CreateConstructor(to.getType(), new TIntermSequence{&from});
828     };
829     switch (convertType)
830     {
831         case ConvertType::OriginalToModified:
832             return Access{coerce(m, o), m};
833         case ConvertType::ModifiedToOriginal:
834             return Access{o, coerce(o, m)};
835     }
836 }
837 
SaturateScalarOrVectorCommon(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing,const bool array)838 bool SaturateScalarOrVectorCommon(ConvertStructState &state,
839                                   const TField &field,
840                                   const TLayoutBlockStorage storage,
841                                   const TLayoutMatrixPacking packing,
842                                   const bool array)
843 {
844     const TType &type = *field.type();
845     if (type.isArray() != array)
846     {
847         return false;
848     }
849     if (!((type.isRank0() && HasScalarBasicType(type)) || type.isVector()))
850     {
851         return false;
852     }
853     const auto saturator =
854         array ? state.config.saturateScalarOrVectorArrays : state.config.saturateScalarOrVector;
855     const int dim        = type.getNominalSize();
856     const int saturation = saturator(field);
857     if (saturation <= dim)
858     {
859         return false;
860     }
861 
862     TType &satType        = SetVectorDim(type, saturation);
863     const bool boolToUint = TestBoolToUint(state, field);
864     if (boolToUint)
865     {
866         satType.setBasicType(TBasicType::EbtUInt);
867     }
868     state.addModifiedField(field, satType, storage, packing, state.symbolEnv.isPointer(field));
869     if (state.symbolEnv.isPointer(field))
870     {
871         state.symbolEnv.removePointer(field);
872     }
873 
874     for (int d = 0; d < dim; ++d)
875     {
876         state.addConversion([=](Access::Env &env, OriginalAccess &o, ModifiedAccess &m) {
877             auto &o_ = dim > 1 ? AccessIndex(o, d) : o;
878             auto &m_ = AccessIndex(m, d);
879             if (boolToUint)
880             {
881                 return ConvertBoolToUint(env.type, o_, m_);
882             }
883             else
884             {
885                 return Access{o_, m_};
886             }
887         });
888     }
889 
890     return true;
891 }
892 
SaturateScalarOrVectorArrays(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)893 bool SaturateScalarOrVectorArrays(ConvertStructState &state,
894                                   const TField &field,
895                                   const TLayoutBlockStorage storage,
896                                   const TLayoutMatrixPacking packing)
897 {
898     return SaturateScalarOrVectorCommon(state, field, storage, packing, true);
899 }
900 
SaturateScalarOrVector(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)901 bool SaturateScalarOrVector(ConvertStructState &state,
902                             const TField &field,
903                             const TLayoutBlockStorage storage,
904                             const TLayoutMatrixPacking packing)
905 {
906     return SaturateScalarOrVectorCommon(state, field, storage, packing, false);
907 }
908 
PromoteBoolToUint(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)909 bool PromoteBoolToUint(ConvertStructState &state,
910                        const TField &field,
911                        const TLayoutBlockStorage storage,
912                        const TLayoutMatrixPacking packing)
913 {
914     if (!TestBoolToUint(state, field))
915     {
916         return false;
917     }
918 
919     auto &promotedType = CloneType(*field.type());
920     promotedType.setBasicType(TBasicType::EbtUInt);
921     state.addModifiedField(field, promotedType, storage, packing, state.symbolEnv.isPointer(field));
922     if (state.symbolEnv.isPointer(field))
923     {
924         state.symbolEnv.removePointer(field);
925     }
926 
927     state.addConversion([=](Access::Env &env, OriginalAccess &o, ModifiedAccess &m) {
928         return ConvertBoolToUint(env.type, o, m);
929     });
930 
931     return true;
932 }
933 
ModifyCommon(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)934 bool ModifyCommon(ConvertStructState &state,
935                   const TField &field,
936                   const TLayoutBlockStorage storage,
937                   const TLayoutMatrixPacking packing)
938 {
939     ModifyFunc funcs[] = {
940         InlineStruct,                  //
941         RecurseStruct,                 //
942         SplitMatrixColumns,            //
943         SaturateMatrixRows,            //
944         SaturateScalarOrVectorArrays,  //
945         SaturateScalarOrVector,        //
946         PromoteBoolToUint,             //
947     };
948 
949     for (ModifyFunc func : funcs)
950     {
951         if (func(state, field, storage, packing))
952         {
953             return true;
954         }
955     }
956 
957     return IdentityModify(state, field, storage, packing);
958 }
959 
InlineArray(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)960 bool InlineArray(ConvertStructState &state,
961                  const TField &field,
962                  const TLayoutBlockStorage storage,
963                  const TLayoutMatrixPacking packing)
964 {
965     const TType &type = *field.type();
966     if (!type.isArray())
967     {
968         return false;
969     }
970     if (!state.config.inlineArray(field))
971     {
972         return false;
973     }
974 
975     const unsigned volume = type.getArraySizeProduct();
976     const bool isMultiDim = type.isArrayOfArrays();
977 
978     auto &innermostType = InnermostType(type);
979 
980     if (isMultiDim)
981     {
982         state.pushPath(FlattenArray());
983     }
984 
985     for (unsigned i = 0; i < volume; ++i)
986     {
987         state.pushPath(i);
988         TType setType(innermostType);
989         if (setType.getLayoutQualifier().locationsSpecified)
990         {
991             TLayoutQualifier qualifier(innermostType.getLayoutQualifier());
992             qualifier.location           = innermostType.getLayoutQualifier().location + i;
993             qualifier.locationsSpecified = 1;
994             setType.setLayoutQualifier(qualifier);
995         }
996         const TField innermostField(&setType, field.name(), field.line(), field.symbolType());
997         ModifyCommon(state, innermostField, storage, packing);
998         state.popPath();
999     }
1000 
1001     if (isMultiDim)
1002     {
1003         state.popPath();
1004     }
1005 
1006     return true;
1007 }
1008 
ModifyRecursive(ConvertStructState & state,const TField & field,const TLayoutBlockStorage storage,const TLayoutMatrixPacking packing)1009 bool ModifyRecursive(ConvertStructState &state,
1010                      const TField &field,
1011                      const TLayoutBlockStorage storage,
1012                      const TLayoutMatrixPacking packing)
1013 {
1014     state.pushPath(field);
1015 
1016     bool modified;
1017     if (InlineArray(state, field, storage, packing))
1018     {
1019         modified = true;
1020     }
1021     else
1022     {
1023         modified = ModifyCommon(state, field, storage, packing);
1024     }
1025 
1026     state.popPath();
1027 
1028     return modified;
1029 }
1030 
1031 }  // anonymous namespace
1032 
1033 ////////////////////////////////////////////////////////////////////////////////
1034 
TryCreateModifiedStruct(TCompiler & compiler,SymbolEnv & symbolEnv,IdGen & idGen,const ModifyStructConfig & config,const TStructure & originalStruct,const Name & modifiedStructName,ModifiedStructMachineries & outMachineries,const bool isUBO,const bool allowPadding)1035 bool sh::TryCreateModifiedStruct(TCompiler &compiler,
1036                                  SymbolEnv &symbolEnv,
1037                                  IdGen &idGen,
1038                                  const ModifyStructConfig &config,
1039                                  const TStructure &originalStruct,
1040                                  const Name &modifiedStructName,
1041                                  ModifiedStructMachineries &outMachineries,
1042                                  const bool isUBO,
1043                                  const bool allowPadding)
1044 {
1045     ConvertStructState state(compiler, symbolEnv, idGen, config, outMachineries, isUBO);
1046     size_t identicalFieldCount = 0;
1047 
1048     const TFieldList &originalFields = originalStruct.fields();
1049     for (TField *originalField : originalFields)
1050     {
1051         const TType &originalType          = *originalField->type();
1052         const TLayoutBlockStorage storage  = Overlay(config.initialBlockStorage, originalType);
1053         const TLayoutMatrixPacking packing = Overlay(config.initialMatrixPacking, originalType);
1054         if (!ModifyRecursive(state, *originalField, storage, packing))
1055         {
1056             ++identicalFieldCount;
1057         }
1058     }
1059 
1060     state.finalize(allowPadding);
1061 
1062     if (identicalFieldCount == originalFields.size() && !state.hasPacking() &&
1063         !state.hasPadding() && !isUBO)
1064     {
1065         return false;
1066     }
1067 
1068     state.publish(originalStruct, modifiedStructName);
1069 
1070     return true;
1071 }
1072