1 /*
2 * Copyright 2017, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "module.h"
18
19 #include <set>
20
21 #include "builder.h"
22 #include "core_defs.h"
23 #include "instructions.h"
24 #include "types_generated.h"
25 #include "word_stream.h"
26
27 namespace android {
28 namespace spirit {
29
30 Module *Module::mInstance = nullptr;
31
getCurrentModule()32 Module *Module::getCurrentModule() {
33 if (mInstance == nullptr) {
34 return mInstance = new Module();
35 }
36 return mInstance;
37 }
38
Module()39 Module::Module()
40 : mNextId(1), mCapabilitiesDeleter(mCapabilities),
41 mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
42 mEntryPointInstsDeleter(mEntryPointInsts),
43 mExecutionModesDeleter(mExecutionModes),
44 mEntryPointsDeleter(mEntryPoints),
45 mFunctionDefinitionsDeleter(mFunctionDefinitions) {
46 mInstance = this;
47 }
48
Module(Builder * b)49 Module::Module(Builder *b)
50 : Entity(b), mNextId(1), mCapabilitiesDeleter(mCapabilities),
51 mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
52 mEntryPointInstsDeleter(mEntryPointInsts),
53 mExecutionModesDeleter(mExecutionModes),
54 mEntryPointsDeleter(mEntryPoints),
55 mFunctionDefinitionsDeleter(mFunctionDefinitions) {
56 mInstance = this;
57 }
58
resolveIds()59 bool Module::resolveIds() {
60 auto &table = mIdTable;
61
62 std::unique_ptr<IVisitor> v0(
63 CreateInstructionVisitor([&table](Instruction *inst) {
64 if (inst->hasResult()) {
65 table.insert(std::make_pair(inst->getId(), inst));
66 }
67 }));
68 v0->visit(this);
69
70 mNextId = mIdTable.rbegin()->first + 1;
71
72 int err = 0;
73 std::unique_ptr<IVisitor> v(
74 CreateInstructionVisitor([&table, &err](Instruction *inst) {
75 for (auto ref : inst->getAllIdRefs()) {
76 if (ref) {
77 auto it = table.find(ref->mId);
78 if (it != table.end()) {
79 ref->mInstruction = it->second;
80 } else {
81 std::cout << "Found no instruction for id " << ref->mId
82 << std::endl;
83 err++;
84 }
85 }
86 }
87 }));
88 v->visit(this);
89 return err == 0;
90 }
91
DeserializeInternal(InputWordStream & IS)92 bool Module::DeserializeInternal(InputWordStream &IS) {
93 if (IS.empty()) {
94 return false;
95 }
96
97 IS >> &mMagicNumber;
98 if (mMagicNumber != 0x07230203) {
99 errs() << "Wrong Magic Number: " << mMagicNumber;
100 return false;
101 }
102
103 if (IS.empty()) {
104 return false;
105 }
106
107 IS >> &mVersion.mWord;
108 if (mVersion.mBytes[0] != 0 || mVersion.mBytes[3] != 0) {
109 return false;
110 }
111
112 if (IS.empty()) {
113 return false;
114 }
115
116 IS >> &mGeneratorMagicNumber >> &mBound >> &mReserved;
117
118 DeserializeZeroOrMore<CapabilityInst>(IS, mCapabilities);
119 DeserializeZeroOrMore<ExtensionInst>(IS, mExtensions);
120 DeserializeZeroOrMore<ExtInstImportInst>(IS, mExtInstImports);
121
122 mMemoryModel.reset(Deserialize<MemoryModelInst>(IS));
123 if (!mMemoryModel) {
124 errs() << "Missing memory model specification.\n";
125 return false;
126 }
127
128 DeserializeZeroOrMore<EntryPointDefinition>(IS, mEntryPoints);
129 DeserializeZeroOrMore<ExecutionModeInst>(IS, mExecutionModes);
130 for (auto entry : mEntryPoints) {
131 mEntryPointInsts.push_back(entry->getInstruction());
132 for (auto mode : mExecutionModes) {
133 entry->applyExecutionMode(mode);
134 }
135 }
136
137 mDebugInfo.reset(Deserialize<DebugInfoSection>(IS));
138 mAnnotations.reset(Deserialize<AnnotationSection>(IS));
139 mGlobals.reset(Deserialize<GlobalSection>(IS));
140
141 DeserializeZeroOrMore<FunctionDefinition>(IS, mFunctionDefinitions);
142
143 if (mFunctionDefinitions.empty()) {
144 errs() << "Missing function definitions.\n";
145 for (int i = 0; i < 4; i++) {
146 uint32_t w;
147 IS >> &w;
148 std::cout << std::hex << w << " ";
149 }
150 std::cout << std::endl;
151 return false;
152 }
153
154 return true;
155 }
156
initialize()157 void Module::initialize() {
158 mMagicNumber = 0x07230203;
159 mVersion.mMajorMinor = {.mMinorNumber = 1, .mMajorNumber = 1};
160 mGeneratorMagicNumber = 0x00070000;
161 mBound = 0;
162 mReserved = 0;
163 mAnnotations.reset(new AnnotationSection());
164 }
165
SerializeHeader(OutputWordStream & OS) const166 void Module::SerializeHeader(OutputWordStream &OS) const {
167 OS << mMagicNumber;
168 OS << mVersion.mWord << mGeneratorMagicNumber;
169 if (mBound == 0) {
170 OS << mIdTable.end()->first + 1;
171 } else {
172 OS << std::max(mBound, mNextId);
173 }
174 OS << mReserved;
175 }
176
Serialize(OutputWordStream & OS) const177 void Module::Serialize(OutputWordStream &OS) const {
178 SerializeHeader(OS);
179 Entity::Serialize(OS);
180 }
181
addCapability(Capability cap)182 Module *Module::addCapability(Capability cap) {
183 mCapabilities.push_back(mBuilder->MakeCapability(cap));
184 return this;
185 }
186
setMemoryModel(AddressingModel am,MemoryModel mm)187 Module *Module::setMemoryModel(AddressingModel am, MemoryModel mm) {
188 mMemoryModel.reset(mBuilder->MakeMemoryModel(am, mm));
189 return this;
190 }
191
addExtInstImport(const char * extName)192 Module *Module::addExtInstImport(const char *extName) {
193 ExtInstImportInst *extInst = mBuilder->MakeExtInstImport(extName);
194 mExtInstImports.push_back(extInst);
195 if (strcmp(extName, "GLSL.std.450") == 0) {
196 mGLExt = extInst;
197 }
198 return this;
199 }
200
addSource(SourceLanguage lang,int version)201 Module *Module::addSource(SourceLanguage lang, int version) {
202 if (!mDebugInfo) {
203 mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
204 }
205 mDebugInfo->addSource(lang, version);
206 return this;
207 }
208
addSourceExtension(const char * ext)209 Module *Module::addSourceExtension(const char *ext) {
210 if (!mDebugInfo) {
211 mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
212 }
213 mDebugInfo->addSourceExtension(ext);
214 return this;
215 }
216
addString(const char * str)217 Module *Module::addString(const char *str) {
218 if (!mDebugInfo) {
219 mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
220 }
221 mDebugInfo->addString(str);
222 return this;
223 }
224
addEntryPoint(EntryPointDefinition * entry)225 Module *Module::addEntryPoint(EntryPointDefinition *entry) {
226 mEntryPoints.push_back(entry);
227 auto newModes = entry->getExecutionModes();
228 mExecutionModes.insert(mExecutionModes.end(), newModes.begin(),
229 newModes.end());
230 return this;
231 }
232
findStringOfPrefix(const char * prefix) const233 const std::string Module::findStringOfPrefix(const char *prefix) const {
234 if (!mDebugInfo) {
235 return std::string();
236 }
237 return mDebugInfo->findStringOfPrefix(prefix);
238 }
239
getGlobalSection()240 GlobalSection *Module::getGlobalSection() {
241 if (!mGlobals) {
242 mGlobals.reset(new GlobalSection());
243 }
244 return mGlobals.get();
245 }
246
getConstant(TypeIntInst * type,int32_t value)247 ConstantInst *Module::getConstant(TypeIntInst *type, int32_t value) {
248 return getGlobalSection()->getConstant(type, value);
249 }
250
getConstant(TypeIntInst * type,uint32_t value)251 ConstantInst *Module::getConstant(TypeIntInst *type, uint32_t value) {
252 return getGlobalSection()->getConstant(type, value);
253 }
254
getConstant(TypeFloatInst * type,float value)255 ConstantInst *Module::getConstant(TypeFloatInst *type, float value) {
256 return getGlobalSection()->getConstant(type, value);
257 }
258
getConstantComposite(TypeVectorInst * type,ConstantInst * components[],size_t width)259 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
260 ConstantInst *components[],
261 size_t width) {
262 return getGlobalSection()->getConstantComposite(type, components, width);
263 }
264
getConstantComposite(TypeVectorInst * type,ConstantInst * comp0,ConstantInst * comp1,ConstantInst * comp2)265 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
266 ConstantInst *comp0,
267 ConstantInst *comp1,
268 ConstantInst *comp2) {
269 // TODO: verify that component types are the same and consistent with the
270 // resulting vector type
271 ConstantInst *comps[] = {comp0, comp1, comp2};
272 return getConstantComposite(type, comps, 3);
273 }
274
getConstantComposite(TypeVectorInst * type,ConstantInst * comp0,ConstantInst * comp1,ConstantInst * comp2,ConstantInst * comp3)275 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
276 ConstantInst *comp0,
277 ConstantInst *comp1,
278 ConstantInst *comp2,
279 ConstantInst *comp3) {
280 // TODO: verify that component types are the same and consistent with the
281 // resulting vector type
282 ConstantInst *comps[] = {comp0, comp1, comp2, comp3};
283 return getConstantComposite(type, comps, 4);
284 }
285
getVoidType()286 TypeVoidInst *Module::getVoidType() {
287 return getGlobalSection()->getVoidType();
288 }
289
getIntType(int bits,bool isSigned)290 TypeIntInst *Module::getIntType(int bits, bool isSigned) {
291 return getGlobalSection()->getIntType(bits, isSigned);
292 }
293
getUnsignedIntType(int bits)294 TypeIntInst *Module::getUnsignedIntType(int bits) {
295 return getIntType(bits, false);
296 }
297
getFloatType(int bits)298 TypeFloatInst *Module::getFloatType(int bits) {
299 return getGlobalSection()->getFloatType(bits);
300 }
301
getVectorType(Instruction * componentType,int width)302 TypeVectorInst *Module::getVectorType(Instruction *componentType, int width) {
303 return getGlobalSection()->getVectorType(componentType, width);
304 }
305
getPointerType(StorageClass storage,Instruction * pointeeType)306 TypePointerInst *Module::getPointerType(StorageClass storage,
307 Instruction *pointeeType) {
308 return getGlobalSection()->getPointerType(storage, pointeeType);
309 }
310
getRuntimeArrayType(Instruction * elementType)311 TypeRuntimeArrayInst *Module::getRuntimeArrayType(Instruction *elementType) {
312 return getGlobalSection()->getRuntimeArrayType(elementType);
313 }
314
getStructType(Instruction * fieldType[],int numField)315 TypeStructInst *Module::getStructType(Instruction *fieldType[], int numField) {
316 return getGlobalSection()->getStructType(fieldType, numField);
317 }
318
getStructType(Instruction * fieldType)319 TypeStructInst *Module::getStructType(Instruction *fieldType) {
320 return getStructType(&fieldType, 1);
321 }
322
getFunctionType(Instruction * retType,Instruction * const argType[],size_t numArg)323 TypeFunctionInst *Module::getFunctionType(Instruction *retType,
324 Instruction *const argType[],
325 size_t numArg) {
326 return getGlobalSection()->getFunctionType(retType, argType, numArg);
327 }
328
329 TypeFunctionInst *
getFunctionType(Instruction * retType,const std::vector<Instruction * > & argTypes)330 Module::getFunctionType(Instruction *retType,
331 const std::vector<Instruction *> &argTypes) {
332 return getGlobalSection()->getFunctionType(retType, argTypes.data(),
333 argTypes.size());
334 }
335
getSize(TypeVoidInst *)336 size_t Module::getSize(TypeVoidInst *) { return 0; }
337
getSize(TypeIntInst * intTy)338 size_t Module::getSize(TypeIntInst *intTy) { return intTy->mOperand1 / 8; }
339
getSize(TypeFloatInst * fpTy)340 size_t Module::getSize(TypeFloatInst *fpTy) { return fpTy->mOperand1 / 8; }
341
getSize(TypeVectorInst * vTy)342 size_t Module::getSize(TypeVectorInst *vTy) {
343 return getSize(vTy->mOperand1.mInstruction) * vTy->mOperand2;
344 }
345
getSize(TypePointerInst *)346 size_t Module::getSize(TypePointerInst *) {
347 return 4; // TODO: or 8?
348 }
349
getSize(TypeStructInst * structTy)350 size_t Module::getSize(TypeStructInst *structTy) {
351 size_t sz = 0;
352 for (auto ty : structTy->mOperand1) {
353 sz += getSize(ty.mInstruction);
354 }
355 return sz;
356 }
357
getSize(TypeFunctionInst *)358 size_t Module::getSize(TypeFunctionInst *) {
359 return 4; // TODO: or 8? Is this just the size of a pointer?
360 }
361
getSize(Instruction * inst)362 size_t Module::getSize(Instruction *inst) {
363 switch (inst->getOpCode()) {
364 case OpTypeVoid:
365 return getSize(static_cast<TypeVoidInst *>(inst));
366 case OpTypeInt:
367 return getSize(static_cast<TypeIntInst *>(inst));
368 case OpTypeFloat:
369 return getSize(static_cast<TypeFloatInst *>(inst));
370 case OpTypeVector:
371 return getSize(static_cast<TypeVectorInst *>(inst));
372 case OpTypeStruct:
373 return getSize(static_cast<TypeStructInst *>(inst));
374 case OpTypeFunction:
375 return getSize(static_cast<TypeFunctionInst *>(inst));
376 default:
377 return 0;
378 }
379 }
380
addFunctionDefinition(FunctionDefinition * func)381 Module *Module::addFunctionDefinition(FunctionDefinition *func) {
382 mFunctionDefinitions.push_back(func);
383 return this;
384 }
385
lookupByName(const char * name) const386 Instruction *Module::lookupByName(const char *name) const {
387 return mDebugInfo->lookupByName(name);
388 }
389
390 FunctionDefinition *
getFunctionDefinitionFromInstruction(FunctionInst * inst) const391 Module::getFunctionDefinitionFromInstruction(FunctionInst *inst) const {
392 for (auto fdef : mFunctionDefinitions) {
393 if (fdef->getInstruction() == inst) {
394 return fdef;
395 }
396 }
397 return nullptr;
398 }
399
400 FunctionDefinition *
lookupFunctionDefinitionByName(const char * name) const401 Module::lookupFunctionDefinitionByName(const char *name) const {
402 FunctionInst *inst = static_cast<FunctionInst *>(lookupByName(name));
403 return getFunctionDefinitionFromInstruction(inst);
404 }
405
lookupNameByInstruction(const Instruction * inst) const406 const char *Module::lookupNameByInstruction(const Instruction *inst) const {
407 return mDebugInfo->lookupNameByInstruction(inst);
408 }
409
getInvocationId()410 VariableInst *Module::getInvocationId() {
411 return getGlobalSection()->getInvocationId();
412 }
413
getNumWorkgroups()414 VariableInst *Module::getNumWorkgroups() {
415 return getGlobalSection()->getNumWorkgroups();
416 }
417
addStructType(TypeStructInst * structType)418 Module *Module::addStructType(TypeStructInst *structType) {
419 getGlobalSection()->addStructType(structType);
420 return this;
421 }
422
addVariable(VariableInst * var)423 Module *Module::addVariable(VariableInst *var) {
424 getGlobalSection()->addVariable(var);
425 return this;
426 }
427
consolidateAnnotations()428 void Module::consolidateAnnotations() {
429 std::vector<Instruction *> annotations(mAnnotations->begin(),
430 mAnnotations->end());
431 std::unique_ptr<IVisitor> v(
432 CreateInstructionVisitor([&annotations](Instruction *inst) -> void {
433 const auto &ann = inst->getAnnotations();
434 annotations.insert(annotations.end(), ann.begin(), ann.end());
435 }));
436 v->visit(this);
437 mAnnotations->clear();
438 mAnnotations->addAnnotations(annotations.begin(), annotations.end());
439 }
440
EntryPointDefinition(Builder * builder,ExecutionModel execModel,FunctionDefinition * func,const char * name)441 EntryPointDefinition::EntryPointDefinition(Builder *builder,
442 ExecutionModel execModel,
443 FunctionDefinition *func,
444 const char *name)
445 : Entity(builder), mFunction(func->getInstruction()),
446 mExecutionModel(execModel) {
447 mName = strndup(name, strlen(name));
448 mEntryPointInst = mBuilder->MakeEntryPoint(execModel, mFunction, mName);
449 (void)mExecutionModel; // suppress unused private field warning
450 }
451
DeserializeInternal(InputWordStream & IS)452 bool EntryPointDefinition::DeserializeInternal(InputWordStream &IS) {
453 if (IS.empty()) {
454 return false;
455 }
456
457 if ((mEntryPointInst = Deserialize<EntryPointInst>(IS))) {
458 return true;
459 }
460
461 return false;
462 }
463
464 EntryPointDefinition *
applyExecutionMode(ExecutionModeInst * mode)465 EntryPointDefinition::applyExecutionMode(ExecutionModeInst *mode) {
466 if (mode->mOperand1.mInstruction == mFunction) {
467 addExecutionMode(mode);
468 }
469 return this;
470 }
471
addToInterface(VariableInst * var)472 EntryPointDefinition *EntryPointDefinition::addToInterface(VariableInst *var) {
473 mInterface.push_back(var);
474 mEntryPointInst->mOperand4.push_back(var);
475 return this;
476 }
477
setLocalSize(uint32_t width,uint32_t height,uint32_t depth)478 EntryPointDefinition *EntryPointDefinition::setLocalSize(uint32_t width,
479 uint32_t height,
480 uint32_t depth) {
481 mLocalSize.mWidth = width;
482 mLocalSize.mHeight = height;
483 mLocalSize.mDepth = depth;
484
485 auto mode = mBuilder->MakeExecutionMode(mFunction, ExecutionMode::LocalSize);
486 mode->addExtraOperand(width)->addExtraOperand(height)->addExtraOperand(depth);
487
488 addExecutionMode(mode);
489
490 return this;
491 }
492
DeserializeInternal(InputWordStream & IS)493 bool DebugInfoSection::DeserializeInternal(InputWordStream &IS) {
494 while (true) {
495 if (auto str = Deserialize<StringInst>(IS)) {
496 mSources.push_back(str);
497 } else if (auto src = Deserialize<SourceInst>(IS)) {
498 mSources.push_back(src);
499 } else if (auto srcExt = Deserialize<SourceExtensionInst>(IS)) {
500 mSources.push_back(srcExt);
501 } else if (auto srcCont = Deserialize<SourceContinuedInst>(IS)) {
502 mSources.push_back(srcCont);
503 } else {
504 break;
505 }
506 }
507
508 while (true) {
509 if (auto name = Deserialize<NameInst>(IS)) {
510 mNames.push_back(name);
511 } else if (auto memName = Deserialize<MemberNameInst>(IS)) {
512 mNames.push_back(memName);
513 } else {
514 break;
515 }
516 }
517
518 return true;
519 }
520
addSource(SourceLanguage lang,int version)521 DebugInfoSection *DebugInfoSection::addSource(SourceLanguage lang,
522 int version) {
523 SourceInst *source = mBuilder->MakeSource(lang, version);
524 mSources.push_back(source);
525 return this;
526 }
527
addSourceExtension(const char * ext)528 DebugInfoSection *DebugInfoSection::addSourceExtension(const char *ext) {
529 SourceExtensionInst *inst = mBuilder->MakeSourceExtension(ext);
530 mSources.push_back(inst);
531 return this;
532 }
533
addString(const char * str)534 DebugInfoSection *DebugInfoSection::addString(const char *str) {
535 StringInst *source = mBuilder->MakeString(str);
536 mSources.push_back(source);
537 return this;
538 }
539
findStringOfPrefix(const char * prefix)540 std::string DebugInfoSection::findStringOfPrefix(const char *prefix) {
541 auto it = std::find_if(
542 mSources.begin(), mSources.end(), [prefix](Instruction *inst) -> bool {
543 if (inst->getOpCode() != OpString) {
544 return false;
545 }
546 const StringInst *strInst = static_cast<const StringInst *>(inst);
547 const std::string &str = strInst->mOperand1;
548 return str.find(prefix) == 0;
549 });
550 if (it == mSources.end()) {
551 return "";
552 }
553 StringInst *strInst = static_cast<StringInst *>(*it);
554 return strInst->mOperand1;
555 }
556
lookupByName(const char * name) const557 Instruction *DebugInfoSection::lookupByName(const char *name) const {
558 for (auto inst : mNames) {
559 if (inst->getOpCode() == OpName) {
560 NameInst *nameInst = static_cast<NameInst *>(inst);
561 if (nameInst->mOperand2.compare(name) == 0) {
562 return nameInst->mOperand1.mInstruction;
563 }
564 }
565 // Ignore member names
566 }
567 return nullptr;
568 }
569
570 const char *
lookupNameByInstruction(const Instruction * target) const571 DebugInfoSection::lookupNameByInstruction(const Instruction *target) const {
572 for (auto inst : mNames) {
573 if (inst->getOpCode() == OpName) {
574 NameInst *nameInst = static_cast<NameInst *>(inst);
575 if (nameInst->mOperand1.mInstruction == target) {
576 return nameInst->mOperand2.c_str();
577 }
578 }
579 // Ignore member names
580 }
581 return nullptr;
582 }
583
AnnotationSection()584 AnnotationSection::AnnotationSection() : mAnnotationsDeleter(mAnnotations) {}
585
AnnotationSection(Builder * b)586 AnnotationSection::AnnotationSection(Builder *b)
587 : Entity(b), mAnnotationsDeleter(mAnnotations) {}
588
DeserializeInternal(InputWordStream & IS)589 bool AnnotationSection::DeserializeInternal(InputWordStream &IS) {
590 while (true) {
591 if (auto decor = Deserialize<DecorateInst>(IS)) {
592 mAnnotations.push_back(decor);
593 } else if (auto decor = Deserialize<MemberDecorateInst>(IS)) {
594 mAnnotations.push_back(decor);
595 } else if (auto decor = Deserialize<GroupDecorateInst>(IS)) {
596 mAnnotations.push_back(decor);
597 } else if (auto decor = Deserialize<GroupMemberDecorateInst>(IS)) {
598 mAnnotations.push_back(decor);
599 } else if (auto decor = Deserialize<DecorationGroupInst>(IS)) {
600 mAnnotations.push_back(decor);
601 } else {
602 break;
603 }
604 }
605 return true;
606 }
607
GlobalSection()608 GlobalSection::GlobalSection() : mGlobalDefsDeleter(mGlobalDefs) {}
609
GlobalSection(Builder * builder)610 GlobalSection::GlobalSection(Builder *builder)
611 : Entity(builder), mGlobalDefsDeleter(mGlobalDefs) {}
612
613 namespace {
614
615 template <typename T>
findOrCreate(std::function<bool (T *)> criteria,std::function<T * ()> factory,std::vector<Instruction * > * globals)616 T *findOrCreate(std::function<bool(T *)> criteria, std::function<T *()> factory,
617 std::vector<Instruction *> *globals) {
618 T *derived;
619 for (auto inst : *globals) {
620 if (inst->getOpCode() == T::mOpCode) {
621 T *derived = static_cast<T *>(inst);
622 if (criteria(derived)) {
623 return derived;
624 }
625 }
626 }
627 derived = factory();
628 globals->push_back(derived);
629 return derived;
630 }
631
632 } // anonymous namespace
633
DeserializeInternal(InputWordStream & IS)634 bool GlobalSection::DeserializeInternal(InputWordStream &IS) {
635 while (true) {
636 #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS) \
637 if (auto typeInst = Deserialize<INST_CLASS>(IS)) { \
638 mGlobalDefs.push_back(typeInst); \
639 continue; \
640 }
641 #include "const_inst_dispatches_generated.h"
642 #include "type_inst_dispatches_generated.h"
643 #undef HANDLE_INSTRUCTION
644
645 if (auto globalInst = Deserialize<VariableInst>(IS)) {
646 // Check if this is function scoped
647 if (globalInst->mOperand1 == StorageClass::Function) {
648 Module::errs() << "warning: Variable (id = " << globalInst->mResult;
649 Module::errs() << ") has function scope in global section.\n";
650 // Khronos LLVM-SPIRV convertor emits "Function" storage-class globals.
651 // As a workaround, accept such SPIR-V code here, and fix it up later
652 // in the rs2spirv compiler by correcting the storage class.
653 // In a stricter deserializer, such code should be rejected, and we
654 // should return false here.
655 }
656 mGlobalDefs.push_back(globalInst);
657 continue;
658 }
659
660 if (auto globalInst = Deserialize<UndefInst>(IS)) {
661 mGlobalDefs.push_back(globalInst);
662 continue;
663 }
664 break;
665 }
666 return true;
667 }
668
getConstant(TypeIntInst * type,int32_t value)669 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, int32_t value) {
670 return findOrCreate<ConstantInst>(
671 [=](ConstantInst *c) { return c->mOperand1.intValue == value; },
672 [=]() -> ConstantInst * {
673 LiteralContextDependentNumber cdn = {.intValue = value};
674 return mBuilder->MakeConstant(type, cdn);
675 },
676 &mGlobalDefs);
677 }
678
getConstant(TypeIntInst * type,uint32_t value)679 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, uint32_t value) {
680 return findOrCreate<ConstantInst>(
681 [=](ConstantInst *c) { return c->mOperand1.intValue == (int)value; },
682 [=]() -> ConstantInst * {
683 LiteralContextDependentNumber cdn = {.intValue = (int)value};
684 return mBuilder->MakeConstant(type, cdn);
685 },
686 &mGlobalDefs);
687 }
688
getConstant(TypeFloatInst * type,float value)689 ConstantInst *GlobalSection::getConstant(TypeFloatInst *type, float value) {
690 return findOrCreate<ConstantInst>(
691 [=](ConstantInst *c) { return c->mOperand1.floatValue == value; },
692 [=]() -> ConstantInst * {
693 LiteralContextDependentNumber cdn = {.floatValue = value};
694 return mBuilder->MakeConstant(type, cdn);
695 },
696 &mGlobalDefs);
697 }
698
699 ConstantCompositeInst *
getConstantComposite(TypeVectorInst * type,ConstantInst * components[],size_t width)700 GlobalSection::getConstantComposite(TypeVectorInst *type,
701 ConstantInst *components[], size_t width) {
702 return findOrCreate<ConstantCompositeInst>(
703 [=](ConstantCompositeInst *c) {
704 if (c->mOperand1.size() != width) {
705 return false;
706 }
707 for (size_t i = 0; i < width; i++) {
708 if (c->mOperand1[i].mInstruction != components[i]) {
709 return false;
710 }
711 }
712 return true;
713 },
714 [=]() -> ConstantCompositeInst * {
715 ConstantCompositeInst *c = mBuilder->MakeConstantComposite(type);
716 for (size_t i = 0; i < width; i++) {
717 c->mOperand1.push_back(components[i]);
718 }
719 return c;
720 },
721 &mGlobalDefs);
722 }
723
getVoidType()724 TypeVoidInst *GlobalSection::getVoidType() {
725 return findOrCreate<TypeVoidInst>(
726 [=](TypeVoidInst *) -> bool { return true; },
727 [=]() -> TypeVoidInst * { return mBuilder->MakeTypeVoid(); },
728 &mGlobalDefs);
729 }
730
getIntType(int bits,bool isSigned)731 TypeIntInst *GlobalSection::getIntType(int bits, bool isSigned) {
732 if (isSigned) {
733 switch (bits) {
734 #define HANDLE_INT_SIZE(INT_TYPE, BITS, SIGNED) \
735 case BITS: { \
736 return findOrCreate<TypeIntInst>( \
737 [=](TypeIntInst *intTy) -> bool { \
738 return intTy->mOperand1 == BITS && intTy->mOperand2 == SIGNED; \
739 }, \
740 [=]() -> TypeIntInst * { \
741 return mBuilder->MakeTypeInt(BITS, SIGNED); \
742 }, \
743 &mGlobalDefs); \
744 }
745 HANDLE_INT_SIZE(Int, 8, 1);
746 HANDLE_INT_SIZE(Int, 16, 1);
747 HANDLE_INT_SIZE(Int, 32, 1);
748 HANDLE_INT_SIZE(Int, 64, 1);
749 default:
750 Module::errs() << "unexpected int type";
751 }
752 } else {
753 switch (bits) {
754 HANDLE_INT_SIZE(UInt, 8, 0);
755 HANDLE_INT_SIZE(UInt, 16, 0);
756 HANDLE_INT_SIZE(UInt, 32, 0);
757 HANDLE_INT_SIZE(UInt, 64, 0);
758 default:
759 Module::errs() << "unexpected int type";
760 }
761 }
762 #undef HANDLE_INT_SIZE
763 return nullptr;
764 }
765
getFloatType(int bits)766 TypeFloatInst *GlobalSection::getFloatType(int bits) {
767 switch (bits) {
768 #define HANDLE_FLOAT_SIZE(BITS) \
769 case BITS: { \
770 return findOrCreate<TypeFloatInst>( \
771 [=](TypeFloatInst *floatTy) -> bool { \
772 return floatTy->mOperand1 == BITS; \
773 }, \
774 [=]() -> TypeFloatInst * { return mBuilder->MakeTypeFloat(BITS); }, \
775 &mGlobalDefs); \
776 }
777 HANDLE_FLOAT_SIZE(16);
778 HANDLE_FLOAT_SIZE(32);
779 HANDLE_FLOAT_SIZE(64);
780 default:
781 Module::errs() << "unexpeced floating point type";
782 }
783 #undef HANDLE_FLOAT_SIZE
784 return nullptr;
785 }
786
getVectorType(Instruction * componentType,int width)787 TypeVectorInst *GlobalSection::getVectorType(Instruction *componentType,
788 int width) {
789 // TODO: verify that componentType is basic numeric types
790
791 return findOrCreate<TypeVectorInst>(
792 [=](TypeVectorInst *vecTy) -> bool {
793 return vecTy->mOperand1.mInstruction == componentType &&
794 vecTy->mOperand2 == width;
795 },
796 [=]() -> TypeVectorInst * {
797 return mBuilder->MakeTypeVector(componentType, width);
798 },
799 &mGlobalDefs);
800 }
801
getPointerType(StorageClass storage,Instruction * pointeeType)802 TypePointerInst *GlobalSection::getPointerType(StorageClass storage,
803 Instruction *pointeeType) {
804 return findOrCreate<TypePointerInst>(
805 [=](TypePointerInst *type) -> bool {
806 return type->mOperand1 == storage &&
807 type->mOperand2.mInstruction == pointeeType;
808 },
809 [=]() -> TypePointerInst * {
810 return mBuilder->MakeTypePointer(storage, pointeeType);
811 },
812 &mGlobalDefs);
813 }
814
815 TypeRuntimeArrayInst *
getRuntimeArrayType(Instruction * elemType)816 GlobalSection::getRuntimeArrayType(Instruction *elemType) {
817 return findOrCreate<TypeRuntimeArrayInst>(
818 [=](TypeRuntimeArrayInst * /*type*/) -> bool {
819 // return type->mOperand1.mInstruction == elemType;
820 return false;
821 },
822 [=]() -> TypeRuntimeArrayInst * {
823 return mBuilder->MakeTypeRuntimeArray(elemType);
824 },
825 &mGlobalDefs);
826 }
827
getStructType(Instruction * fieldType[],int numField)828 TypeStructInst *GlobalSection::getStructType(Instruction *fieldType[],
829 int numField) {
830 TypeStructInst *structTy = mBuilder->MakeTypeStruct();
831 for (int i = 0; i < numField; i++) {
832 structTy->mOperand1.push_back(fieldType[i]);
833 }
834 mGlobalDefs.push_back(structTy);
835 return structTy;
836 }
837
getFunctionType(Instruction * retType,Instruction * const argType[],size_t numArg)838 TypeFunctionInst *GlobalSection::getFunctionType(Instruction *retType,
839 Instruction *const argType[],
840 size_t numArg) {
841 return findOrCreate<TypeFunctionInst>(
842 [=](TypeFunctionInst *type) -> bool {
843 if (type->mOperand1.mInstruction != retType ||
844 type->mOperand2.size() != numArg) {
845 return false;
846 }
847 for (size_t i = 0; i < numArg; i++) {
848 if (type->mOperand2[i].mInstruction != argType[i]) {
849 return false;
850 }
851 }
852 return true;
853 },
854 [=]() -> TypeFunctionInst * {
855 TypeFunctionInst *funcTy = mBuilder->MakeTypeFunction(retType);
856 for (size_t i = 0; i < numArg; i++) {
857 funcTy->mOperand2.push_back(argType[i]);
858 }
859 return funcTy;
860 },
861 &mGlobalDefs);
862 }
863
addStructType(TypeStructInst * structType)864 GlobalSection *GlobalSection::addStructType(TypeStructInst *structType) {
865 mGlobalDefs.push_back(structType);
866 return this;
867 }
868
addVariable(VariableInst * var)869 GlobalSection *GlobalSection::addVariable(VariableInst *var) {
870 mGlobalDefs.push_back(var);
871 return this;
872 }
873
getInvocationId()874 VariableInst *GlobalSection::getInvocationId() {
875 if (mInvocationId) {
876 return mInvocationId.get();
877 }
878
879 TypeIntInst *UIntTy = getIntType(32, false);
880 TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
881 TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
882
883 VariableInst *InvocationId =
884 mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
885 InvocationId->decorate(Decoration::BuiltIn)
886 ->addExtraOperand(static_cast<uint32_t>(BuiltIn::GlobalInvocationId));
887
888 mInvocationId.reset(InvocationId);
889
890 return InvocationId;
891 }
892
getNumWorkgroups()893 VariableInst *GlobalSection::getNumWorkgroups() {
894 if (mNumWorkgroups) {
895 return mNumWorkgroups.get();
896 }
897
898 TypeIntInst *UIntTy = getIntType(32, false);
899 TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
900 TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
901
902 VariableInst *GNum = mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
903 GNum->decorate(Decoration::BuiltIn)
904 ->addExtraOperand(static_cast<uint32_t>(BuiltIn::NumWorkgroups));
905
906 mNumWorkgroups.reset(GNum);
907
908 return GNum;
909 }
910
DeserializeInternal(InputWordStream & IS)911 bool FunctionDeclaration::DeserializeInternal(InputWordStream &IS) {
912 if (!(mFunc = Deserialize<FunctionInst>(IS))) {
913 return false;
914 }
915
916 DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
917
918 if (!(mFuncEnd = Deserialize<FunctionEndInst>(IS))) {
919 return false;
920 }
921
922 return true;
923 }
924
Deserialize(InputWordStream & IS)925 template <> Instruction *Deserialize(InputWordStream &IS) {
926 Instruction *inst;
927
928 switch ((*IS) & 0xFFFF) {
929 #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS) \
930 case OPCODE: \
931 inst = Deserialize<INST_CLASS>(IS); \
932 break;
933 #include "instruction_dispatches_generated.h"
934 #undef HANDLE_INSTRUCTION
935 default:
936 Module::errs() << "unrecognized instruction";
937 inst = nullptr;
938 }
939
940 return inst;
941 }
942
DeserializeInternal(InputWordStream & IS)943 bool Block::DeserializeInternal(InputWordStream &IS) {
944 Instruction *inst;
945 while (((*IS) & 0xFFFF) != OpFunctionEnd &&
946 (inst = Deserialize<Instruction>(IS))) {
947 mInsts.push_back(inst);
948 if (inst->getOpCode() == OpBranch ||
949 inst->getOpCode() == OpBranchConditional ||
950 inst->getOpCode() == OpSwitch || inst->getOpCode() == OpKill ||
951 inst->getOpCode() == OpReturn || inst->getOpCode() == OpReturnValue ||
952 inst->getOpCode() == OpUnreachable) {
953 break;
954 }
955 }
956 return !mInsts.empty();
957 }
958
FunctionDefinition()959 FunctionDefinition::FunctionDefinition()
960 : mParamsDeleter(mParams), mBlocksDeleter(mBlocks) {}
961
FunctionDefinition(Builder * builder,FunctionInst * func,FunctionEndInst * end)962 FunctionDefinition::FunctionDefinition(Builder *builder, FunctionInst *func,
963 FunctionEndInst *end)
964 : Entity(builder), mFunc(func), mFuncEnd(end), mParamsDeleter(mParams),
965 mBlocksDeleter(mBlocks) {}
966
DeserializeInternal(InputWordStream & IS)967 bool FunctionDefinition::DeserializeInternal(InputWordStream &IS) {
968 mFunc.reset(Deserialize<FunctionInst>(IS));
969 if (!mFunc) {
970 return false;
971 }
972
973 DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
974 DeserializeZeroOrMore<Block>(IS, mBlocks);
975
976 mFuncEnd.reset(Deserialize<FunctionEndInst>(IS));
977 if (!mFuncEnd) {
978 return false;
979 }
980
981 return true;
982 }
983
getReturnType() const984 Instruction *FunctionDefinition::getReturnType() const {
985 return mFunc->mResultType.mInstruction;
986 }
987
988 } // namespace spirit
989 } // namespace android
990