• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2017, The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "module.h"

#include <algorithm>
#include <cstring>
#include <iostream>
#include <set>

#include "builder.h"
#include "core_defs.h"
#include "instructions.h"
#include "types_generated.h"
#include "word_stream.h"
26 
27 namespace android {
28 namespace spirit {
29 
30 Module *Module::mInstance = nullptr;
31 
getCurrentModule()32 Module *Module::getCurrentModule() {
33   if (mInstance == nullptr) {
34     return mInstance = new Module();
35   }
36   return mInstance;
37 }
38 
Module()39 Module::Module()
40     : mNextId(1), mCapabilitiesDeleter(mCapabilities),
41       mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
42       mEntryPointInstsDeleter(mEntryPointInsts),
43       mExecutionModesDeleter(mExecutionModes),
44       mEntryPointsDeleter(mEntryPoints),
45       mFunctionDefinitionsDeleter(mFunctionDefinitions) {
46   mInstance = this;
47 }
48 
Module(Builder * b)49 Module::Module(Builder *b)
50     : Entity(b), mNextId(1), mCapabilitiesDeleter(mCapabilities),
51       mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
52       mEntryPointInstsDeleter(mEntryPointInsts),
53       mExecutionModesDeleter(mExecutionModes),
54       mEntryPointsDeleter(mEntryPoints),
55       mFunctionDefinitionsDeleter(mFunctionDefinitions) {
56   mInstance = this;
57 }
58 
resolveIds()59 bool Module::resolveIds() {
60   auto &table = mIdTable;
61 
62   std::unique_ptr<IVisitor> v0(
63       CreateInstructionVisitor([&table](Instruction *inst) {
64         if (inst->hasResult()) {
65           table.insert(std::make_pair(inst->getId(), inst));
66         }
67       }));
68   v0->visit(this);
69 
70   mNextId = mIdTable.rbegin()->first + 1;
71 
72   int err = 0;
73   std::unique_ptr<IVisitor> v(
74       CreateInstructionVisitor([&table, &err](Instruction *inst) {
75         for (auto ref : inst->getAllIdRefs()) {
76           if (ref) {
77             auto it = table.find(ref->mId);
78             if (it != table.end()) {
79               ref->mInstruction = it->second;
80             } else {
81               std::cout << "Found no instruction for id " << ref->mId
82                         << std::endl;
83               err++;
84             }
85           }
86         }
87       }));
88   v->visit(this);
89   return err == 0;
90 }
91 
DeserializeInternal(InputWordStream & IS)92 bool Module::DeserializeInternal(InputWordStream &IS) {
93   if (IS.empty()) {
94     return false;
95   }
96 
97   IS >> &mMagicNumber;
98   if (mMagicNumber != 0x07230203) {
99     errs() << "Wrong Magic Number: " << mMagicNumber;
100     return false;
101   }
102 
103   if (IS.empty()) {
104     return false;
105   }
106 
107   IS >> &mVersion.mWord;
108   if (mVersion.mBytes[0] != 0 || mVersion.mBytes[3] != 0) {
109     return false;
110   }
111 
112   if (IS.empty()) {
113     return false;
114   }
115 
116   IS >> &mGeneratorMagicNumber >> &mBound >> &mReserved;
117 
118   DeserializeZeroOrMore<CapabilityInst>(IS, mCapabilities);
119   DeserializeZeroOrMore<ExtensionInst>(IS, mExtensions);
120   DeserializeZeroOrMore<ExtInstImportInst>(IS, mExtInstImports);
121 
122   mMemoryModel.reset(Deserialize<MemoryModelInst>(IS));
123   if (!mMemoryModel) {
124     errs() << "Missing memory model specification.\n";
125     return false;
126   }
127 
128   DeserializeZeroOrMore<EntryPointDefinition>(IS, mEntryPoints);
129   DeserializeZeroOrMore<ExecutionModeInst>(IS, mExecutionModes);
130   for (auto entry : mEntryPoints) {
131     mEntryPointInsts.push_back(entry->getInstruction());
132     for (auto mode : mExecutionModes) {
133       entry->applyExecutionMode(mode);
134     }
135   }
136 
137   mDebugInfo.reset(Deserialize<DebugInfoSection>(IS));
138   mAnnotations.reset(Deserialize<AnnotationSection>(IS));
139   mGlobals.reset(Deserialize<GlobalSection>(IS));
140 
141   DeserializeZeroOrMore<FunctionDefinition>(IS, mFunctionDefinitions);
142 
143   if (mFunctionDefinitions.empty()) {
144     errs() << "Missing function definitions.\n";
145     for (int i = 0; i < 4; i++) {
146       uint32_t w;
147       IS >> &w;
148       std::cout << std::hex << w << " ";
149     }
150     std::cout << std::endl;
151     return false;
152   }
153 
154   return true;
155 }
156 
initialize()157 void Module::initialize() {
158   mMagicNumber = 0x07230203;
159   mVersion.mMajorMinor = {.mMinorNumber = 1, .mMajorNumber = 1};
160   mGeneratorMagicNumber = 0x00070000;
161   mBound = 0;
162   mReserved = 0;
163   mAnnotations.reset(new AnnotationSection());
164 }
165 
SerializeHeader(OutputWordStream & OS) const166 void Module::SerializeHeader(OutputWordStream &OS) const {
167   OS << mMagicNumber;
168   OS << mVersion.mWord << mGeneratorMagicNumber;
169   if (mBound == 0) {
170     OS << mIdTable.end()->first + 1;
171   } else {
172     OS << std::max(mBound, mNextId);
173   }
174   OS << mReserved;
175 }
176 
Serialize(OutputWordStream & OS) const177 void Module::Serialize(OutputWordStream &OS) const {
178   SerializeHeader(OS);
179   Entity::Serialize(OS);
180 }
181 
addCapability(Capability cap)182 Module *Module::addCapability(Capability cap) {
183   mCapabilities.push_back(mBuilder->MakeCapability(cap));
184   return this;
185 }
186 
setMemoryModel(AddressingModel am,MemoryModel mm)187 Module *Module::setMemoryModel(AddressingModel am, MemoryModel mm) {
188   mMemoryModel.reset(mBuilder->MakeMemoryModel(am, mm));
189   return this;
190 }
191 
addExtInstImport(const char * extName)192 Module *Module::addExtInstImport(const char *extName) {
193   ExtInstImportInst *extInst = mBuilder->MakeExtInstImport(extName);
194   mExtInstImports.push_back(extInst);
195   if (strcmp(extName, "GLSL.std.450") == 0) {
196     mGLExt = extInst;
197   }
198   return this;
199 }
200 
addSource(SourceLanguage lang,int version)201 Module *Module::addSource(SourceLanguage lang, int version) {
202   if (!mDebugInfo) {
203     mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
204   }
205   mDebugInfo->addSource(lang, version);
206   return this;
207 }
208 
addSourceExtension(const char * ext)209 Module *Module::addSourceExtension(const char *ext) {
210   if (!mDebugInfo) {
211     mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
212   }
213   mDebugInfo->addSourceExtension(ext);
214   return this;
215 }
216 
addString(const char * str)217 Module *Module::addString(const char *str) {
218   if (!mDebugInfo) {
219     mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
220   }
221   mDebugInfo->addString(str);
222   return this;
223 }
224 
addEntryPoint(EntryPointDefinition * entry)225 Module *Module::addEntryPoint(EntryPointDefinition *entry) {
226   mEntryPoints.push_back(entry);
227   auto newModes = entry->getExecutionModes();
228   mExecutionModes.insert(mExecutionModes.end(), newModes.begin(),
229                          newModes.end());
230   return this;
231 }
232 
findStringOfPrefix(const char * prefix) const233 const std::string Module::findStringOfPrefix(const char *prefix) const {
234   if (!mDebugInfo) {
235     return std::string();
236   }
237   return mDebugInfo->findStringOfPrefix(prefix);
238 }
239 
getGlobalSection()240 GlobalSection *Module::getGlobalSection() {
241   if (!mGlobals) {
242     mGlobals.reset(new GlobalSection());
243   }
244   return mGlobals.get();
245 }
246 
getConstant(TypeIntInst * type,int32_t value)247 ConstantInst *Module::getConstant(TypeIntInst *type, int32_t value) {
248   return getGlobalSection()->getConstant(type, value);
249 }
250 
getConstant(TypeIntInst * type,uint32_t value)251 ConstantInst *Module::getConstant(TypeIntInst *type, uint32_t value) {
252   return getGlobalSection()->getConstant(type, value);
253 }
254 
getConstant(TypeFloatInst * type,float value)255 ConstantInst *Module::getConstant(TypeFloatInst *type, float value) {
256   return getGlobalSection()->getConstant(type, value);
257 }
258 
getConstantComposite(TypeVectorInst * type,ConstantInst * components[],size_t width)259 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
260                                                     ConstantInst *components[],
261                                                     size_t width) {
262   return getGlobalSection()->getConstantComposite(type, components, width);
263 }
264 
getConstantComposite(TypeVectorInst * type,ConstantInst * comp0,ConstantInst * comp1,ConstantInst * comp2)265 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
266                                                     ConstantInst *comp0,
267                                                     ConstantInst *comp1,
268                                                     ConstantInst *comp2) {
269   // TODO: verify that component types are the same and consistent with the
270   // resulting vector type
271   ConstantInst *comps[] = {comp0, comp1, comp2};
272   return getConstantComposite(type, comps, 3);
273 }
274 
getConstantComposite(TypeVectorInst * type,ConstantInst * comp0,ConstantInst * comp1,ConstantInst * comp2,ConstantInst * comp3)275 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
276                                                     ConstantInst *comp0,
277                                                     ConstantInst *comp1,
278                                                     ConstantInst *comp2,
279                                                     ConstantInst *comp3) {
280   // TODO: verify that component types are the same and consistent with the
281   // resulting vector type
282   ConstantInst *comps[] = {comp0, comp1, comp2, comp3};
283   return getConstantComposite(type, comps, 4);
284 }
285 
getVoidType()286 TypeVoidInst *Module::getVoidType() {
287   return getGlobalSection()->getVoidType();
288 }
289 
getIntType(int bits,bool isSigned)290 TypeIntInst *Module::getIntType(int bits, bool isSigned) {
291   return getGlobalSection()->getIntType(bits, isSigned);
292 }
293 
getUnsignedIntType(int bits)294 TypeIntInst *Module::getUnsignedIntType(int bits) {
295   return getIntType(bits, false);
296 }
297 
getFloatType(int bits)298 TypeFloatInst *Module::getFloatType(int bits) {
299   return getGlobalSection()->getFloatType(bits);
300 }
301 
getVectorType(Instruction * componentType,int width)302 TypeVectorInst *Module::getVectorType(Instruction *componentType, int width) {
303   return getGlobalSection()->getVectorType(componentType, width);
304 }
305 
getPointerType(StorageClass storage,Instruction * pointeeType)306 TypePointerInst *Module::getPointerType(StorageClass storage,
307                                         Instruction *pointeeType) {
308   return getGlobalSection()->getPointerType(storage, pointeeType);
309 }
310 
getRuntimeArrayType(Instruction * elementType)311 TypeRuntimeArrayInst *Module::getRuntimeArrayType(Instruction *elementType) {
312   return getGlobalSection()->getRuntimeArrayType(elementType);
313 }
314 
getStructType(Instruction * fieldType[],int numField)315 TypeStructInst *Module::getStructType(Instruction *fieldType[], int numField) {
316   return getGlobalSection()->getStructType(fieldType, numField);
317 }
318 
getStructType(Instruction * fieldType)319 TypeStructInst *Module::getStructType(Instruction *fieldType) {
320   return getStructType(&fieldType, 1);
321 }
322 
getFunctionType(Instruction * retType,Instruction * const argType[],size_t numArg)323 TypeFunctionInst *Module::getFunctionType(Instruction *retType,
324                                           Instruction *const argType[],
325                                           size_t numArg) {
326   return getGlobalSection()->getFunctionType(retType, argType, numArg);
327 }
328 
329 TypeFunctionInst *
getFunctionType(Instruction * retType,const std::vector<Instruction * > & argTypes)330 Module::getFunctionType(Instruction *retType,
331                         const std::vector<Instruction *> &argTypes) {
332   return getGlobalSection()->getFunctionType(retType, argTypes.data(),
333                                              argTypes.size());
334 }
335 
getSize(TypeVoidInst *)336 size_t Module::getSize(TypeVoidInst *) { return 0; }
337 
getSize(TypeIntInst * intTy)338 size_t Module::getSize(TypeIntInst *intTy) { return intTy->mOperand1 / 8; }
339 
getSize(TypeFloatInst * fpTy)340 size_t Module::getSize(TypeFloatInst *fpTy) { return fpTy->mOperand1 / 8; }
341 
getSize(TypeVectorInst * vTy)342 size_t Module::getSize(TypeVectorInst *vTy) {
343   return getSize(vTy->mOperand1.mInstruction) * vTy->mOperand2;
344 }
345 
getSize(TypePointerInst *)346 size_t Module::getSize(TypePointerInst *) {
347   return 4; // TODO: or 8?
348 }
349 
getSize(TypeStructInst * structTy)350 size_t Module::getSize(TypeStructInst *structTy) {
351   size_t sz = 0;
352   for (auto ty : structTy->mOperand1) {
353     sz += getSize(ty.mInstruction);
354   }
355   return sz;
356 }
357 
getSize(TypeFunctionInst *)358 size_t Module::getSize(TypeFunctionInst *) {
359   return 4; // TODO: or 8? Is this just the size of a pointer?
360 }
361 
getSize(Instruction * inst)362 size_t Module::getSize(Instruction *inst) {
363   switch (inst->getOpCode()) {
364   case OpTypeVoid:
365     return getSize(static_cast<TypeVoidInst *>(inst));
366   case OpTypeInt:
367     return getSize(static_cast<TypeIntInst *>(inst));
368   case OpTypeFloat:
369     return getSize(static_cast<TypeFloatInst *>(inst));
370   case OpTypeVector:
371     return getSize(static_cast<TypeVectorInst *>(inst));
372   case OpTypeStruct:
373     return getSize(static_cast<TypeStructInst *>(inst));
374   case OpTypeFunction:
375     return getSize(static_cast<TypeFunctionInst *>(inst));
376   default:
377     return 0;
378   }
379 }
380 
addFunctionDefinition(FunctionDefinition * func)381 Module *Module::addFunctionDefinition(FunctionDefinition *func) {
382   mFunctionDefinitions.push_back(func);
383   return this;
384 }
385 
lookupByName(const char * name) const386 Instruction *Module::lookupByName(const char *name) const {
387   return mDebugInfo->lookupByName(name);
388 }
389 
390 FunctionDefinition *
getFunctionDefinitionFromInstruction(FunctionInst * inst) const391 Module::getFunctionDefinitionFromInstruction(FunctionInst *inst) const {
392   for (auto fdef : mFunctionDefinitions) {
393     if (fdef->getInstruction() == inst) {
394       return fdef;
395     }
396   }
397   return nullptr;
398 }
399 
400 FunctionDefinition *
lookupFunctionDefinitionByName(const char * name) const401 Module::lookupFunctionDefinitionByName(const char *name) const {
402   FunctionInst *inst = static_cast<FunctionInst *>(lookupByName(name));
403   return getFunctionDefinitionFromInstruction(inst);
404 }
405 
lookupNameByInstruction(const Instruction * inst) const406 const char *Module::lookupNameByInstruction(const Instruction *inst) const {
407   return mDebugInfo->lookupNameByInstruction(inst);
408 }
409 
getInvocationId()410 VariableInst *Module::getInvocationId() {
411   return getGlobalSection()->getInvocationId();
412 }
413 
getNumWorkgroups()414 VariableInst *Module::getNumWorkgroups() {
415   return getGlobalSection()->getNumWorkgroups();
416 }
417 
addStructType(TypeStructInst * structType)418 Module *Module::addStructType(TypeStructInst *structType) {
419   getGlobalSection()->addStructType(structType);
420   return this;
421 }
422 
addVariable(VariableInst * var)423 Module *Module::addVariable(VariableInst *var) {
424   getGlobalSection()->addVariable(var);
425   return this;
426 }
427 
consolidateAnnotations()428 void Module::consolidateAnnotations() {
429   std::vector<Instruction *> annotations(mAnnotations->begin(),
430                                       mAnnotations->end());
431   std::unique_ptr<IVisitor> v(
432       CreateInstructionVisitor([&annotations](Instruction *inst) -> void {
433         const auto &ann = inst->getAnnotations();
434         annotations.insert(annotations.end(), ann.begin(), ann.end());
435       }));
436   v->visit(this);
437   mAnnotations->clear();
438   mAnnotations->addAnnotations(annotations.begin(), annotations.end());
439 }
440 
EntryPointDefinition(Builder * builder,ExecutionModel execModel,FunctionDefinition * func,const char * name)441 EntryPointDefinition::EntryPointDefinition(Builder *builder,
442                                            ExecutionModel execModel,
443                                            FunctionDefinition *func,
444                                            const char *name)
445     : Entity(builder), mFunction(func->getInstruction()),
446       mExecutionModel(execModel) {
447   mName = strndup(name, strlen(name));
448   mEntryPointInst = mBuilder->MakeEntryPoint(execModel, mFunction, mName);
449 }
450 
DeserializeInternal(InputWordStream & IS)451 bool EntryPointDefinition::DeserializeInternal(InputWordStream &IS) {
452   if (IS.empty()) {
453     return false;
454   }
455 
456   if ((mEntryPointInst = Deserialize<EntryPointInst>(IS))) {
457     return true;
458   }
459 
460   return false;
461 }
462 
463 EntryPointDefinition *
applyExecutionMode(ExecutionModeInst * mode)464 EntryPointDefinition::applyExecutionMode(ExecutionModeInst *mode) {
465   if (mode->mOperand1.mInstruction == mFunction) {
466     addExecutionMode(mode);
467   }
468   return this;
469 }
470 
addToInterface(VariableInst * var)471 EntryPointDefinition *EntryPointDefinition::addToInterface(VariableInst *var) {
472   mInterface.push_back(var);
473   mEntryPointInst->mOperand4.push_back(var);
474   return this;
475 }
476 
setLocalSize(uint32_t width,uint32_t height,uint32_t depth)477 EntryPointDefinition *EntryPointDefinition::setLocalSize(uint32_t width,
478                                                          uint32_t height,
479                                                          uint32_t depth) {
480   mLocalSize.mWidth = width;
481   mLocalSize.mHeight = height;
482   mLocalSize.mDepth = depth;
483 
484   auto mode = mBuilder->MakeExecutionMode(mFunction, ExecutionMode::LocalSize);
485   mode->addExtraOperand(width)->addExtraOperand(height)->addExtraOperand(depth);
486 
487   addExecutionMode(mode);
488 
489   return this;
490 }
491 
DeserializeInternal(InputWordStream & IS)492 bool DebugInfoSection::DeserializeInternal(InputWordStream &IS) {
493   while (true) {
494     if (auto str = Deserialize<StringInst>(IS)) {
495       mSources.push_back(str);
496     } else if (auto src = Deserialize<SourceInst>(IS)) {
497       mSources.push_back(src);
498     } else if (auto srcExt = Deserialize<SourceExtensionInst>(IS)) {
499       mSources.push_back(srcExt);
500     } else if (auto srcCont = Deserialize<SourceContinuedInst>(IS)) {
501       mSources.push_back(srcCont);
502     } else {
503       break;
504     }
505   }
506 
507   while (true) {
508     if (auto name = Deserialize<NameInst>(IS)) {
509       mNames.push_back(name);
510     } else if (auto memName = Deserialize<MemberNameInst>(IS)) {
511       mNames.push_back(memName);
512     } else {
513       break;
514     }
515   }
516 
517   return true;
518 }
519 
addSource(SourceLanguage lang,int version)520 DebugInfoSection *DebugInfoSection::addSource(SourceLanguage lang,
521                                               int version) {
522   SourceInst *source = mBuilder->MakeSource(lang, version);
523   mSources.push_back(source);
524   return this;
525 }
526 
addSourceExtension(const char * ext)527 DebugInfoSection *DebugInfoSection::addSourceExtension(const char *ext) {
528   SourceExtensionInst *inst = mBuilder->MakeSourceExtension(ext);
529   mSources.push_back(inst);
530   return this;
531 }
532 
addString(const char * str)533 DebugInfoSection *DebugInfoSection::addString(const char *str) {
534   StringInst *source = mBuilder->MakeString(str);
535   mSources.push_back(source);
536   return this;
537 }
538 
findStringOfPrefix(const char * prefix)539 std::string DebugInfoSection::findStringOfPrefix(const char *prefix) {
540   auto it = std::find_if(
541       mSources.begin(), mSources.end(), [prefix](Instruction *inst) -> bool {
542         if (inst->getOpCode() != OpString) {
543           return false;
544         }
545         const StringInst *strInst = static_cast<const StringInst *>(inst);
546         const std::string &str = strInst->mOperand1;
547         return str.find(prefix) == 0;
548       });
549   if (it == mSources.end()) {
550     return "";
551   }
552   StringInst *strInst = static_cast<StringInst *>(*it);
553   return strInst->mOperand1;
554 }
555 
lookupByName(const char * name) const556 Instruction *DebugInfoSection::lookupByName(const char *name) const {
557   for (auto inst : mNames) {
558     if (inst->getOpCode() == OpName) {
559       NameInst *nameInst = static_cast<NameInst *>(inst);
560       if (nameInst->mOperand2.compare(name) == 0) {
561         return nameInst->mOperand1.mInstruction;
562       }
563     }
564     // Ignore member names
565   }
566   return nullptr;
567 }
568 
569 const char *
lookupNameByInstruction(const Instruction * target) const570 DebugInfoSection::lookupNameByInstruction(const Instruction *target) const {
571   for (auto inst : mNames) {
572     if (inst->getOpCode() == OpName) {
573       NameInst *nameInst = static_cast<NameInst *>(inst);
574       if (nameInst->mOperand1.mInstruction == target) {
575         return nameInst->mOperand2.c_str();
576       }
577     }
578     // Ignore member names
579   }
580   return nullptr;
581 }
582 
AnnotationSection()583 AnnotationSection::AnnotationSection() : mAnnotationsDeleter(mAnnotations) {}
584 
AnnotationSection(Builder * b)585 AnnotationSection::AnnotationSection(Builder *b)
586     : Entity(b), mAnnotationsDeleter(mAnnotations) {}
587 
DeserializeInternal(InputWordStream & IS)588 bool AnnotationSection::DeserializeInternal(InputWordStream &IS) {
589   while (true) {
590     if (auto decor = Deserialize<DecorateInst>(IS)) {
591       mAnnotations.push_back(decor);
592     } else if (auto decor = Deserialize<MemberDecorateInst>(IS)) {
593       mAnnotations.push_back(decor);
594     } else if (auto decor = Deserialize<GroupDecorateInst>(IS)) {
595       mAnnotations.push_back(decor);
596     } else if (auto decor = Deserialize<GroupMemberDecorateInst>(IS)) {
597       mAnnotations.push_back(decor);
598     } else if (auto decor = Deserialize<DecorationGroupInst>(IS)) {
599       mAnnotations.push_back(decor);
600     } else {
601       break;
602     }
603   }
604   return true;
605 }
606 
GlobalSection()607 GlobalSection::GlobalSection() : mGlobalDefsDeleter(mGlobalDefs) {}
608 
GlobalSection(Builder * builder)609 GlobalSection::GlobalSection(Builder *builder)
610     : Entity(builder), mGlobalDefsDeleter(mGlobalDefs) {}
611 
612 namespace {
613 
614 template <typename T>
findOrCreate(std::function<bool (T *)> criteria,std::function<T * ()> factory,std::vector<Instruction * > * globals)615 T *findOrCreate(std::function<bool(T *)> criteria, std::function<T *()> factory,
616                 std::vector<Instruction *> *globals) {
617   T *derived;
618   for (auto inst : *globals) {
619     if (inst->getOpCode() == T::mOpCode) {
620       T *derived = static_cast<T *>(inst);
621       if (criteria(derived)) {
622         return derived;
623       }
624     }
625   }
626   derived = factory();
627   globals->push_back(derived);
628   return derived;
629 }
630 
631 } // anonymous namespace
632 
DeserializeInternal(InputWordStream & IS)633 bool GlobalSection::DeserializeInternal(InputWordStream &IS) {
634   while (true) {
635 #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS)                                 \
636   if (auto typeInst = Deserialize<INST_CLASS>(IS)) {                           \
637     mGlobalDefs.push_back(typeInst);                                           \
638     continue;                                                                  \
639   }
640 #include "const_inst_dispatches_generated.h"
641 #include "type_inst_dispatches_generated.h"
642 #undef HANDLE_INSTRUCTION
643 
644     if (auto globalInst = Deserialize<VariableInst>(IS)) {
645       // Check if this is function scoped
646       if (globalInst->mOperand1 == StorageClass::Function) {
647         Module::errs() << "warning: Variable (id = " << globalInst->mResult;
648         Module::errs() << ") has function scope in global section.\n";
649         // Khronos LLVM-SPIRV convertor emits "Function" storage-class globals.
650         // As a workaround, accept such SPIR-V code here, and fix it up later
651         // in the rs2spirv compiler by correcting the storage class.
652         // In a stricter deserializer, such code should be rejected, and we
653         // should return false here.
654       }
655       mGlobalDefs.push_back(globalInst);
656       continue;
657     }
658 
659     if (auto globalInst = Deserialize<UndefInst>(IS)) {
660       mGlobalDefs.push_back(globalInst);
661       continue;
662     }
663     break;
664   }
665   return true;
666 }
667 
getConstant(TypeIntInst * type,int32_t value)668 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, int32_t value) {
669   return findOrCreate<ConstantInst>(
670       [=](ConstantInst *c) { return c->mOperand1.intValue == value; },
671       [=]() -> ConstantInst * {
672         LiteralContextDependentNumber cdn = {.intValue = value};
673         return mBuilder->MakeConstant(type, cdn);
674       },
675       &mGlobalDefs);
676 }
677 
getConstant(TypeIntInst * type,uint32_t value)678 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, uint32_t value) {
679   return findOrCreate<ConstantInst>(
680       [=](ConstantInst *c) { return c->mOperand1.intValue == (int)value; },
681       [=]() -> ConstantInst * {
682         LiteralContextDependentNumber cdn = {.intValue = (int)value};
683         return mBuilder->MakeConstant(type, cdn);
684       },
685       &mGlobalDefs);
686 }
687 
getConstant(TypeFloatInst * type,float value)688 ConstantInst *GlobalSection::getConstant(TypeFloatInst *type, float value) {
689   return findOrCreate<ConstantInst>(
690       [=](ConstantInst *c) { return c->mOperand1.floatValue == value; },
691       [=]() -> ConstantInst * {
692         LiteralContextDependentNumber cdn = {.floatValue = value};
693         return mBuilder->MakeConstant(type, cdn);
694       },
695       &mGlobalDefs);
696 }
697 
698 ConstantCompositeInst *
getConstantComposite(TypeVectorInst * type,ConstantInst * components[],size_t width)699 GlobalSection::getConstantComposite(TypeVectorInst *type,
700                                     ConstantInst *components[], size_t width) {
701   return findOrCreate<ConstantCompositeInst>(
702       [=](ConstantCompositeInst *c) {
703         if (c->mOperand1.size() != width) {
704           return false;
705         }
706         for (size_t i = 0; i < width; i++) {
707           if (c->mOperand1[i].mInstruction != components[i]) {
708             return false;
709           }
710         }
711         return true;
712       },
713       [=]() -> ConstantCompositeInst * {
714         ConstantCompositeInst *c = mBuilder->MakeConstantComposite(type);
715         for (size_t i = 0; i < width; i++) {
716           c->mOperand1.push_back(components[i]);
717         }
718         return c;
719       },
720       &mGlobalDefs);
721 }
722 
getVoidType()723 TypeVoidInst *GlobalSection::getVoidType() {
724   return findOrCreate<TypeVoidInst>(
725       [=](TypeVoidInst *) -> bool { return true; },
726       [=]() -> TypeVoidInst * { return mBuilder->MakeTypeVoid(); },
727       &mGlobalDefs);
728 }
729 
getIntType(int bits,bool isSigned)730 TypeIntInst *GlobalSection::getIntType(int bits, bool isSigned) {
731   if (isSigned) {
732     switch (bits) {
733 #define HANDLE_INT_SIZE(INT_TYPE, BITS, SIGNED)                                \
734   case BITS: {                                                                 \
735     return findOrCreate<TypeIntInst>(                                          \
736         [=](TypeIntInst *intTy) -> bool {                                      \
737           return intTy->mOperand1 == BITS && intTy->mOperand2 == SIGNED;       \
738         },                                                                     \
739         [=]() -> TypeIntInst * {                                               \
740           return mBuilder->MakeTypeInt(BITS, SIGNED);                          \
741         },                                                                     \
742         &mGlobalDefs);                                                         \
743   }
744       HANDLE_INT_SIZE(Int, 8, 1);
745       HANDLE_INT_SIZE(Int, 16, 1);
746       HANDLE_INT_SIZE(Int, 32, 1);
747       HANDLE_INT_SIZE(Int, 64, 1);
748     default:
749       Module::errs() << "unexpected int type";
750     }
751   } else {
752     switch (bits) {
753       HANDLE_INT_SIZE(UInt, 8, 0);
754       HANDLE_INT_SIZE(UInt, 16, 0);
755       HANDLE_INT_SIZE(UInt, 32, 0);
756       HANDLE_INT_SIZE(UInt, 64, 0);
757     default:
758       Module::errs() << "unexpected int type";
759     }
760   }
761 #undef HANDLE_INT_SIZE
762   return nullptr;
763 }
764 
getFloatType(int bits)765 TypeFloatInst *GlobalSection::getFloatType(int bits) {
766   switch (bits) {
767 #define HANDLE_FLOAT_SIZE(BITS)                                                \
768   case BITS: {                                                                 \
769     return findOrCreate<TypeFloatInst>(                                        \
770         [=](TypeFloatInst *floatTy) -> bool {                                  \
771           return floatTy->mOperand1 == BITS;                                   \
772         },                                                                     \
773         [=]() -> TypeFloatInst * { return mBuilder->MakeTypeFloat(BITS); },    \
774         &mGlobalDefs);                                                         \
775   }
776     HANDLE_FLOAT_SIZE(16);
777     HANDLE_FLOAT_SIZE(32);
778     HANDLE_FLOAT_SIZE(64);
779   default:
780     Module::errs() << "unexpeced floating point type";
781   }
782 #undef HANDLE_FLOAT_SIZE
783   return nullptr;
784 }
785 
getVectorType(Instruction * componentType,int width)786 TypeVectorInst *GlobalSection::getVectorType(Instruction *componentType,
787                                              int width) {
788   // TODO: verify that componentType is basic numeric types
789 
790   return findOrCreate<TypeVectorInst>(
791       [=](TypeVectorInst *vecTy) -> bool {
792         return vecTy->mOperand1.mInstruction == componentType &&
793                vecTy->mOperand2 == width;
794       },
795       [=]() -> TypeVectorInst * {
796         return mBuilder->MakeTypeVector(componentType, width);
797       },
798       &mGlobalDefs);
799 }
800 
getPointerType(StorageClass storage,Instruction * pointeeType)801 TypePointerInst *GlobalSection::getPointerType(StorageClass storage,
802                                                Instruction *pointeeType) {
803   return findOrCreate<TypePointerInst>(
804       [=](TypePointerInst *type) -> bool {
805         return type->mOperand1 == storage &&
806                type->mOperand2.mInstruction == pointeeType;
807       },
808       [=]() -> TypePointerInst * {
809         return mBuilder->MakeTypePointer(storage, pointeeType);
810       },
811       &mGlobalDefs);
812 }
813 
814 TypeRuntimeArrayInst *
getRuntimeArrayType(Instruction * elemType)815 GlobalSection::getRuntimeArrayType(Instruction *elemType) {
816   return findOrCreate<TypeRuntimeArrayInst>(
817       [=](TypeRuntimeArrayInst * /*type*/) -> bool {
818         // return type->mOperand1.mInstruction == elemType;
819         return false;
820       },
821       [=]() -> TypeRuntimeArrayInst * {
822         return mBuilder->MakeTypeRuntimeArray(elemType);
823       },
824       &mGlobalDefs);
825 }
826 
getStructType(Instruction * fieldType[],int numField)827 TypeStructInst *GlobalSection::getStructType(Instruction *fieldType[],
828                                              int numField) {
829   TypeStructInst *structTy = mBuilder->MakeTypeStruct();
830   for (int i = 0; i < numField; i++) {
831     structTy->mOperand1.push_back(fieldType[i]);
832   }
833   mGlobalDefs.push_back(structTy);
834   return structTy;
835 }
836 
getFunctionType(Instruction * retType,Instruction * const argType[],size_t numArg)837 TypeFunctionInst *GlobalSection::getFunctionType(Instruction *retType,
838                                                  Instruction *const argType[],
839                                                  size_t numArg) {
840   return findOrCreate<TypeFunctionInst>(
841       [=](TypeFunctionInst *type) -> bool {
842         if (type->mOperand1.mInstruction != retType ||
843             type->mOperand2.size() != numArg) {
844           return false;
845         }
846         for (size_t i = 0; i < numArg; i++) {
847           if (type->mOperand2[i].mInstruction != argType[i]) {
848             return false;
849           }
850         }
851         return true;
852       },
853       [=]() -> TypeFunctionInst * {
854         TypeFunctionInst *funcTy = mBuilder->MakeTypeFunction(retType);
855         for (size_t i = 0; i < numArg; i++) {
856           funcTy->mOperand2.push_back(argType[i]);
857         }
858         return funcTy;
859       },
860       &mGlobalDefs);
861 }
862 
addStructType(TypeStructInst * structType)863 GlobalSection *GlobalSection::addStructType(TypeStructInst *structType) {
864   mGlobalDefs.push_back(structType);
865   return this;
866 }
867 
addVariable(VariableInst * var)868 GlobalSection *GlobalSection::addVariable(VariableInst *var) {
869   mGlobalDefs.push_back(var);
870   return this;
871 }
872 
getInvocationId()873 VariableInst *GlobalSection::getInvocationId() {
874   if (mInvocationId) {
875     return mInvocationId.get();
876   }
877 
878   TypeIntInst *UIntTy = getIntType(32, false);
879   TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
880   TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
881 
882   VariableInst *InvocationId =
883       mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
884   InvocationId->decorate(Decoration::BuiltIn)
885       ->addExtraOperand(static_cast<uint32_t>(BuiltIn::GlobalInvocationId));
886 
887   mInvocationId.reset(InvocationId);
888 
889   return InvocationId;
890 }
891 
getNumWorkgroups()892 VariableInst *GlobalSection::getNumWorkgroups() {
893   if (mNumWorkgroups) {
894     return mNumWorkgroups.get();
895   }
896 
897   TypeIntInst *UIntTy = getIntType(32, false);
898   TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
899   TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
900 
901   VariableInst *GNum = mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
902   GNum->decorate(Decoration::BuiltIn)
903       ->addExtraOperand(static_cast<uint32_t>(BuiltIn::NumWorkgroups));
904 
905   mNumWorkgroups.reset(GNum);
906 
907   return GNum;
908 }
909 
DeserializeInternal(InputWordStream & IS)910 bool FunctionDeclaration::DeserializeInternal(InputWordStream &IS) {
911   if (!Deserialize<FunctionInst>(IS)) {
912     return false;
913   }
914 
915   DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
916 
917   if (!Deserialize<FunctionEndInst>(IS)) {
918     return false;
919   }
920 
921   return true;
922 }
923 
Deserialize(InputWordStream & IS)924 template <> Instruction *Deserialize(InputWordStream &IS) {
925   Instruction *inst;
926 
927   switch ((*IS) & 0xFFFF) {
928 #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS)                                 \
929   case OPCODE:                                                                 \
930     inst = Deserialize<INST_CLASS>(IS);                                        \
931     break;
932 #include "instruction_dispatches_generated.h"
933 #undef HANDLE_INSTRUCTION
934   default:
935     Module::errs() << "unrecognized instruction";
936     inst = nullptr;
937   }
938 
939   return inst;
940 }
941 
DeserializeInternal(InputWordStream & IS)942 bool Block::DeserializeInternal(InputWordStream &IS) {
943   Instruction *inst;
944   while (((*IS) & 0xFFFF) != OpFunctionEnd &&
945          (inst = Deserialize<Instruction>(IS))) {
946     mInsts.push_back(inst);
947     if (inst->getOpCode() == OpBranch ||
948         inst->getOpCode() == OpBranchConditional ||
949         inst->getOpCode() == OpSwitch || inst->getOpCode() == OpKill ||
950         inst->getOpCode() == OpReturn || inst->getOpCode() == OpReturnValue ||
951         inst->getOpCode() == OpUnreachable) {
952       break;
953     }
954   }
955   return !mInsts.empty();
956 }
957 
FunctionDefinition()958 FunctionDefinition::FunctionDefinition()
959     : mParamsDeleter(mParams), mBlocksDeleter(mBlocks) {}
960 
FunctionDefinition(Builder * builder,FunctionInst * func,FunctionEndInst * end)961 FunctionDefinition::FunctionDefinition(Builder *builder, FunctionInst *func,
962                                        FunctionEndInst *end)
963     : Entity(builder), mFunc(func), mFuncEnd(end), mParamsDeleter(mParams),
964       mBlocksDeleter(mBlocks) {}
965 
DeserializeInternal(InputWordStream & IS)966 bool FunctionDefinition::DeserializeInternal(InputWordStream &IS) {
967   mFunc.reset(Deserialize<FunctionInst>(IS));
968   if (!mFunc) {
969     return false;
970   }
971 
972   DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
973   DeserializeZeroOrMore<Block>(IS, mBlocks);
974 
975   mFuncEnd.reset(Deserialize<FunctionEndInst>(IS));
976   if (!mFuncEnd) {
977     return false;
978   }
979 
980   return true;
981 }
982 
getReturnType() const983 Instruction *FunctionDefinition::getReturnType() const {
984   return mFunc->mResultType.mInstruction;
985 }
986 
987 } // namespace spirit
988 } // namespace android
989