1 /*
2 * Copyright 2017, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "module.h"

#include <algorithm>
#include <cstring>
#include <iostream>
#include <set>

#include "builder.h"
#include "core_defs.h"
#include "instructions.h"
#include "types_generated.h"
#include "word_stream.h"
26
27 namespace android {
28 namespace spirit {
29
30 Module *Module::mInstance = nullptr;
31
getCurrentModule()32 Module *Module::getCurrentModule() {
33 if (mInstance == nullptr) {
34 return mInstance = new Module();
35 }
36 return mInstance;
37 }
38
// Default constructor. It does not hand a Builder to Entity. Ids start at 1.
// Each *Deleter member is bound to its matching container (presumably so the
// contained instructions/definitions are freed with the Module — TODO confirm
// in module.h). The new instance becomes the current module returned by
// getCurrentModule().
Module::Module()
    : mNextId(1), mCapabilitiesDeleter(mCapabilities),
      mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
      mEntryPointInstsDeleter(mEntryPointInsts),
      mExecutionModesDeleter(mExecutionModes),
      mEntryPointsDeleter(mEntryPoints),
      mFunctionDefinitionsDeleter(mFunctionDefinitions) {
  // Register as the current module.
  mInstance = this;
}

// Builder-aware constructor. `b` is passed to Entity, which makes it available
// to the mBuilder->Make*() calls in this file. Otherwise it is identical to the
// default constructor.
Module::Module(Builder *b)
    : Entity(b), mNextId(1), mCapabilitiesDeleter(mCapabilities),
      mExtensionsDeleter(mExtensions), mExtInstImportsDeleter(mExtInstImports),
      mEntryPointInstsDeleter(mEntryPointInsts),
      mExecutionModesDeleter(mExecutionModes),
      mEntryPointsDeleter(mEntryPoints),
      mFunctionDefinitionsDeleter(mFunctionDefinitions) {
  // Register as the current module.
  mInstance = this;
}
58
resolveIds()59 bool Module::resolveIds() {
60 auto &table = mIdTable;
61
62 std::unique_ptr<IVisitor> v0(
63 CreateInstructionVisitor([&table](Instruction *inst) {
64 if (inst->hasResult()) {
65 table.insert(std::make_pair(inst->getId(), inst));
66 }
67 }));
68 v0->visit(this);
69
70 mNextId = mIdTable.rbegin()->first + 1;
71
72 int err = 0;
73 std::unique_ptr<IVisitor> v(
74 CreateInstructionVisitor([&table, &err](Instruction *inst) {
75 for (auto ref : inst->getAllIdRefs()) {
76 if (ref) {
77 auto it = table.find(ref->mId);
78 if (it != table.end()) {
79 ref->mInstruction = it->second;
80 } else {
81 std::cout << "Found no instruction for id " << ref->mId
82 << std::endl;
83 err++;
84 }
85 }
86 }
87 }));
88 v->visit(this);
89 return err == 0;
90 }
91
DeserializeInternal(InputWordStream & IS)92 bool Module::DeserializeInternal(InputWordStream &IS) {
93 if (IS.empty()) {
94 return false;
95 }
96
97 IS >> &mMagicNumber;
98 if (mMagicNumber != 0x07230203) {
99 errs() << "Wrong Magic Number: " << mMagicNumber;
100 return false;
101 }
102
103 if (IS.empty()) {
104 return false;
105 }
106
107 IS >> &mVersion.mWord;
108 if (mVersion.mBytes[0] != 0 || mVersion.mBytes[3] != 0) {
109 return false;
110 }
111
112 if (IS.empty()) {
113 return false;
114 }
115
116 IS >> &mGeneratorMagicNumber >> &mBound >> &mReserved;
117
118 DeserializeZeroOrMore<CapabilityInst>(IS, mCapabilities);
119 DeserializeZeroOrMore<ExtensionInst>(IS, mExtensions);
120 DeserializeZeroOrMore<ExtInstImportInst>(IS, mExtInstImports);
121
122 mMemoryModel.reset(Deserialize<MemoryModelInst>(IS));
123 if (!mMemoryModel) {
124 errs() << "Missing memory model specification.\n";
125 return false;
126 }
127
128 DeserializeZeroOrMore<EntryPointDefinition>(IS, mEntryPoints);
129 DeserializeZeroOrMore<ExecutionModeInst>(IS, mExecutionModes);
130 for (auto entry : mEntryPoints) {
131 mEntryPointInsts.push_back(entry->getInstruction());
132 for (auto mode : mExecutionModes) {
133 entry->applyExecutionMode(mode);
134 }
135 }
136
137 mDebugInfo.reset(Deserialize<DebugInfoSection>(IS));
138 mAnnotations.reset(Deserialize<AnnotationSection>(IS));
139 mGlobals.reset(Deserialize<GlobalSection>(IS));
140
141 DeserializeZeroOrMore<FunctionDefinition>(IS, mFunctionDefinitions);
142
143 if (mFunctionDefinitions.empty()) {
144 errs() << "Missing function definitions.\n";
145 for (int i = 0; i < 4; i++) {
146 uint32_t w;
147 IS >> &w;
148 std::cout << std::hex << w << " ";
149 }
150 std::cout << std::endl;
151 return false;
152 }
153
154 return true;
155 }
156
initialize()157 void Module::initialize() {
158 mMagicNumber = 0x07230203;
159 mVersion.mMajorMinor = {.mMinorNumber = 1, .mMajorNumber = 1};
160 mGeneratorMagicNumber = 0x00070000;
161 mBound = 0;
162 mReserved = 0;
163 mAnnotations.reset(new AnnotationSection());
164 }
165
SerializeHeader(OutputWordStream & OS) const166 void Module::SerializeHeader(OutputWordStream &OS) const {
167 OS << mMagicNumber;
168 OS << mVersion.mWord << mGeneratorMagicNumber;
169 if (mBound == 0) {
170 OS << mIdTable.end()->first + 1;
171 } else {
172 OS << std::max(mBound, mNextId);
173 }
174 OS << mReserved;
175 }
176
Serialize(OutputWordStream & OS) const177 void Module::Serialize(OutputWordStream &OS) const {
178 SerializeHeader(OS);
179 Entity::Serialize(OS);
180 }
181
addCapability(Capability cap)182 Module *Module::addCapability(Capability cap) {
183 mCapabilities.push_back(mBuilder->MakeCapability(cap));
184 return this;
185 }
186
setMemoryModel(AddressingModel am,MemoryModel mm)187 Module *Module::setMemoryModel(AddressingModel am, MemoryModel mm) {
188 mMemoryModel.reset(mBuilder->MakeMemoryModel(am, mm));
189 return this;
190 }
191
addExtInstImport(const char * extName)192 Module *Module::addExtInstImport(const char *extName) {
193 ExtInstImportInst *extInst = mBuilder->MakeExtInstImport(extName);
194 mExtInstImports.push_back(extInst);
195 if (strcmp(extName, "GLSL.std.450") == 0) {
196 mGLExt = extInst;
197 }
198 return this;
199 }
200
addSource(SourceLanguage lang,int version)201 Module *Module::addSource(SourceLanguage lang, int version) {
202 if (!mDebugInfo) {
203 mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
204 }
205 mDebugInfo->addSource(lang, version);
206 return this;
207 }
208
addSourceExtension(const char * ext)209 Module *Module::addSourceExtension(const char *ext) {
210 if (!mDebugInfo) {
211 mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
212 }
213 mDebugInfo->addSourceExtension(ext);
214 return this;
215 }
216
addString(const char * str)217 Module *Module::addString(const char *str) {
218 if (!mDebugInfo) {
219 mDebugInfo.reset(mBuilder->MakeDebugInfoSection());
220 }
221 mDebugInfo->addString(str);
222 return this;
223 }
224
addEntryPoint(EntryPointDefinition * entry)225 Module *Module::addEntryPoint(EntryPointDefinition *entry) {
226 mEntryPoints.push_back(entry);
227 auto newModes = entry->getExecutionModes();
228 mExecutionModes.insert(mExecutionModes.end(), newModes.begin(),
229 newModes.end());
230 return this;
231 }
232
findStringOfPrefix(const char * prefix) const233 const std::string Module::findStringOfPrefix(const char *prefix) const {
234 if (!mDebugInfo) {
235 return std::string();
236 }
237 return mDebugInfo->findStringOfPrefix(prefix);
238 }
239
getGlobalSection()240 GlobalSection *Module::getGlobalSection() {
241 if (!mGlobals) {
242 mGlobals.reset(new GlobalSection());
243 }
244 return mGlobals.get();
245 }
246
getConstant(TypeIntInst * type,int32_t value)247 ConstantInst *Module::getConstant(TypeIntInst *type, int32_t value) {
248 return getGlobalSection()->getConstant(type, value);
249 }
250
getConstant(TypeIntInst * type,uint32_t value)251 ConstantInst *Module::getConstant(TypeIntInst *type, uint32_t value) {
252 return getGlobalSection()->getConstant(type, value);
253 }
254
getConstant(TypeFloatInst * type,float value)255 ConstantInst *Module::getConstant(TypeFloatInst *type, float value) {
256 return getGlobalSection()->getConstant(type, value);
257 }
258
getConstantComposite(TypeVectorInst * type,ConstantInst * components[],size_t width)259 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
260 ConstantInst *components[],
261 size_t width) {
262 return getGlobalSection()->getConstantComposite(type, components, width);
263 }
264
getConstantComposite(TypeVectorInst * type,ConstantInst * comp0,ConstantInst * comp1,ConstantInst * comp2)265 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
266 ConstantInst *comp0,
267 ConstantInst *comp1,
268 ConstantInst *comp2) {
269 // TODO: verify that component types are the same and consistent with the
270 // resulting vector type
271 ConstantInst *comps[] = {comp0, comp1, comp2};
272 return getConstantComposite(type, comps, 3);
273 }
274
getConstantComposite(TypeVectorInst * type,ConstantInst * comp0,ConstantInst * comp1,ConstantInst * comp2,ConstantInst * comp3)275 ConstantCompositeInst *Module::getConstantComposite(TypeVectorInst *type,
276 ConstantInst *comp0,
277 ConstantInst *comp1,
278 ConstantInst *comp2,
279 ConstantInst *comp3) {
280 // TODO: verify that component types are the same and consistent with the
281 // resulting vector type
282 ConstantInst *comps[] = {comp0, comp1, comp2, comp3};
283 return getConstantComposite(type, comps, 4);
284 }
285
getVoidType()286 TypeVoidInst *Module::getVoidType() {
287 return getGlobalSection()->getVoidType();
288 }
289
getIntType(int bits,bool isSigned)290 TypeIntInst *Module::getIntType(int bits, bool isSigned) {
291 return getGlobalSection()->getIntType(bits, isSigned);
292 }
293
getUnsignedIntType(int bits)294 TypeIntInst *Module::getUnsignedIntType(int bits) {
295 return getIntType(bits, false);
296 }
297
getFloatType(int bits)298 TypeFloatInst *Module::getFloatType(int bits) {
299 return getGlobalSection()->getFloatType(bits);
300 }
301
getVectorType(Instruction * componentType,int width)302 TypeVectorInst *Module::getVectorType(Instruction *componentType, int width) {
303 return getGlobalSection()->getVectorType(componentType, width);
304 }
305
getPointerType(StorageClass storage,Instruction * pointeeType)306 TypePointerInst *Module::getPointerType(StorageClass storage,
307 Instruction *pointeeType) {
308 return getGlobalSection()->getPointerType(storage, pointeeType);
309 }
310
getRuntimeArrayType(Instruction * elementType)311 TypeRuntimeArrayInst *Module::getRuntimeArrayType(Instruction *elementType) {
312 return getGlobalSection()->getRuntimeArrayType(elementType);
313 }
314
getStructType(Instruction * fieldType[],int numField)315 TypeStructInst *Module::getStructType(Instruction *fieldType[], int numField) {
316 return getGlobalSection()->getStructType(fieldType, numField);
317 }
318
getStructType(Instruction * fieldType)319 TypeStructInst *Module::getStructType(Instruction *fieldType) {
320 return getStructType(&fieldType, 1);
321 }
322
getFunctionType(Instruction * retType,Instruction * const argType[],size_t numArg)323 TypeFunctionInst *Module::getFunctionType(Instruction *retType,
324 Instruction *const argType[],
325 size_t numArg) {
326 return getGlobalSection()->getFunctionType(retType, argType, numArg);
327 }
328
329 TypeFunctionInst *
getFunctionType(Instruction * retType,const std::vector<Instruction * > & argTypes)330 Module::getFunctionType(Instruction *retType,
331 const std::vector<Instruction *> &argTypes) {
332 return getGlobalSection()->getFunctionType(retType, argTypes.data(),
333 argTypes.size());
334 }
335
getSize(TypeVoidInst *)336 size_t Module::getSize(TypeVoidInst *) { return 0; }
337
getSize(TypeIntInst * intTy)338 size_t Module::getSize(TypeIntInst *intTy) { return intTy->mOperand1 / 8; }
339
getSize(TypeFloatInst * fpTy)340 size_t Module::getSize(TypeFloatInst *fpTy) { return fpTy->mOperand1 / 8; }
341
getSize(TypeVectorInst * vTy)342 size_t Module::getSize(TypeVectorInst *vTy) {
343 return getSize(vTy->mOperand1.mInstruction) * vTy->mOperand2;
344 }
345
getSize(TypePointerInst *)346 size_t Module::getSize(TypePointerInst *) {
347 return 4; // TODO: or 8?
348 }
349
getSize(TypeStructInst * structTy)350 size_t Module::getSize(TypeStructInst *structTy) {
351 size_t sz = 0;
352 for (auto ty : structTy->mOperand1) {
353 sz += getSize(ty.mInstruction);
354 }
355 return sz;
356 }
357
getSize(TypeFunctionInst *)358 size_t Module::getSize(TypeFunctionInst *) {
359 return 4; // TODO: or 8? Is this just the size of a pointer?
360 }
361
getSize(Instruction * inst)362 size_t Module::getSize(Instruction *inst) {
363 switch (inst->getOpCode()) {
364 case OpTypeVoid:
365 return getSize(static_cast<TypeVoidInst *>(inst));
366 case OpTypeInt:
367 return getSize(static_cast<TypeIntInst *>(inst));
368 case OpTypeFloat:
369 return getSize(static_cast<TypeFloatInst *>(inst));
370 case OpTypeVector:
371 return getSize(static_cast<TypeVectorInst *>(inst));
372 case OpTypeStruct:
373 return getSize(static_cast<TypeStructInst *>(inst));
374 case OpTypeFunction:
375 return getSize(static_cast<TypeFunctionInst *>(inst));
376 default:
377 return 0;
378 }
379 }
380
addFunctionDefinition(FunctionDefinition * func)381 Module *Module::addFunctionDefinition(FunctionDefinition *func) {
382 mFunctionDefinitions.push_back(func);
383 return this;
384 }
385
lookupByName(const char * name) const386 Instruction *Module::lookupByName(const char *name) const {
387 return mDebugInfo->lookupByName(name);
388 }
389
390 FunctionDefinition *
getFunctionDefinitionFromInstruction(FunctionInst * inst) const391 Module::getFunctionDefinitionFromInstruction(FunctionInst *inst) const {
392 for (auto fdef : mFunctionDefinitions) {
393 if (fdef->getInstruction() == inst) {
394 return fdef;
395 }
396 }
397 return nullptr;
398 }
399
400 FunctionDefinition *
lookupFunctionDefinitionByName(const char * name) const401 Module::lookupFunctionDefinitionByName(const char *name) const {
402 FunctionInst *inst = static_cast<FunctionInst *>(lookupByName(name));
403 return getFunctionDefinitionFromInstruction(inst);
404 }
405
lookupNameByInstruction(const Instruction * inst) const406 const char *Module::lookupNameByInstruction(const Instruction *inst) const {
407 return mDebugInfo->lookupNameByInstruction(inst);
408 }
409
getInvocationId()410 VariableInst *Module::getInvocationId() {
411 return getGlobalSection()->getInvocationId();
412 }
413
getNumWorkgroups()414 VariableInst *Module::getNumWorkgroups() {
415 return getGlobalSection()->getNumWorkgroups();
416 }
417
addStructType(TypeStructInst * structType)418 Module *Module::addStructType(TypeStructInst *structType) {
419 getGlobalSection()->addStructType(structType);
420 return this;
421 }
422
addVariable(VariableInst * var)423 Module *Module::addVariable(VariableInst *var) {
424 getGlobalSection()->addVariable(var);
425 return this;
426 }
427
consolidateAnnotations()428 void Module::consolidateAnnotations() {
429 std::vector<Instruction *> annotations(mAnnotations->begin(),
430 mAnnotations->end());
431 std::unique_ptr<IVisitor> v(
432 CreateInstructionVisitor([&annotations](Instruction *inst) -> void {
433 const auto &ann = inst->getAnnotations();
434 annotations.insert(annotations.end(), ann.begin(), ann.end());
435 }));
436 v->visit(this);
437 mAnnotations->clear();
438 mAnnotations->addAnnotations(annotations.begin(), annotations.end());
439 }
440
EntryPointDefinition(Builder * builder,ExecutionModel execModel,FunctionDefinition * func,const char * name)441 EntryPointDefinition::EntryPointDefinition(Builder *builder,
442 ExecutionModel execModel,
443 FunctionDefinition *func,
444 const char *name)
445 : Entity(builder), mFunction(func->getInstruction()),
446 mExecutionModel(execModel) {
447 mName = strndup(name, strlen(name));
448 mEntryPointInst = mBuilder->MakeEntryPoint(execModel, mFunction, mName);
449 }
450
DeserializeInternal(InputWordStream & IS)451 bool EntryPointDefinition::DeserializeInternal(InputWordStream &IS) {
452 if (IS.empty()) {
453 return false;
454 }
455
456 if ((mEntryPointInst = Deserialize<EntryPointInst>(IS))) {
457 return true;
458 }
459
460 return false;
461 }
462
463 EntryPointDefinition *
applyExecutionMode(ExecutionModeInst * mode)464 EntryPointDefinition::applyExecutionMode(ExecutionModeInst *mode) {
465 if (mode->mOperand1.mInstruction == mFunction) {
466 addExecutionMode(mode);
467 }
468 return this;
469 }
470
addToInterface(VariableInst * var)471 EntryPointDefinition *EntryPointDefinition::addToInterface(VariableInst *var) {
472 mInterface.push_back(var);
473 mEntryPointInst->mOperand4.push_back(var);
474 return this;
475 }
476
setLocalSize(uint32_t width,uint32_t height,uint32_t depth)477 EntryPointDefinition *EntryPointDefinition::setLocalSize(uint32_t width,
478 uint32_t height,
479 uint32_t depth) {
480 mLocalSize.mWidth = width;
481 mLocalSize.mHeight = height;
482 mLocalSize.mDepth = depth;
483
484 auto mode = mBuilder->MakeExecutionMode(mFunction, ExecutionMode::LocalSize);
485 mode->addExtraOperand(width)->addExtraOperand(height)->addExtraOperand(depth);
486
487 addExecutionMode(mode);
488
489 return this;
490 }
491
DeserializeInternal(InputWordStream & IS)492 bool DebugInfoSection::DeserializeInternal(InputWordStream &IS) {
493 while (true) {
494 if (auto str = Deserialize<StringInst>(IS)) {
495 mSources.push_back(str);
496 } else if (auto src = Deserialize<SourceInst>(IS)) {
497 mSources.push_back(src);
498 } else if (auto srcExt = Deserialize<SourceExtensionInst>(IS)) {
499 mSources.push_back(srcExt);
500 } else if (auto srcCont = Deserialize<SourceContinuedInst>(IS)) {
501 mSources.push_back(srcCont);
502 } else {
503 break;
504 }
505 }
506
507 while (true) {
508 if (auto name = Deserialize<NameInst>(IS)) {
509 mNames.push_back(name);
510 } else if (auto memName = Deserialize<MemberNameInst>(IS)) {
511 mNames.push_back(memName);
512 } else {
513 break;
514 }
515 }
516
517 return true;
518 }
519
addSource(SourceLanguage lang,int version)520 DebugInfoSection *DebugInfoSection::addSource(SourceLanguage lang,
521 int version) {
522 SourceInst *source = mBuilder->MakeSource(lang, version);
523 mSources.push_back(source);
524 return this;
525 }
526
addSourceExtension(const char * ext)527 DebugInfoSection *DebugInfoSection::addSourceExtension(const char *ext) {
528 SourceExtensionInst *inst = mBuilder->MakeSourceExtension(ext);
529 mSources.push_back(inst);
530 return this;
531 }
532
addString(const char * str)533 DebugInfoSection *DebugInfoSection::addString(const char *str) {
534 StringInst *source = mBuilder->MakeString(str);
535 mSources.push_back(source);
536 return this;
537 }
538
findStringOfPrefix(const char * prefix)539 std::string DebugInfoSection::findStringOfPrefix(const char *prefix) {
540 auto it = std::find_if(
541 mSources.begin(), mSources.end(), [prefix](Instruction *inst) -> bool {
542 if (inst->getOpCode() != OpString) {
543 return false;
544 }
545 const StringInst *strInst = static_cast<const StringInst *>(inst);
546 const std::string &str = strInst->mOperand1;
547 return str.find(prefix) == 0;
548 });
549 if (it == mSources.end()) {
550 return "";
551 }
552 StringInst *strInst = static_cast<StringInst *>(*it);
553 return strInst->mOperand1;
554 }
555
lookupByName(const char * name) const556 Instruction *DebugInfoSection::lookupByName(const char *name) const {
557 for (auto inst : mNames) {
558 if (inst->getOpCode() == OpName) {
559 NameInst *nameInst = static_cast<NameInst *>(inst);
560 if (nameInst->mOperand2.compare(name) == 0) {
561 return nameInst->mOperand1.mInstruction;
562 }
563 }
564 // Ignore member names
565 }
566 return nullptr;
567 }
568
569 const char *
lookupNameByInstruction(const Instruction * target) const570 DebugInfoSection::lookupNameByInstruction(const Instruction *target) const {
571 for (auto inst : mNames) {
572 if (inst->getOpCode() == OpName) {
573 NameInst *nameInst = static_cast<NameInst *>(inst);
574 if (nameInst->mOperand1.mInstruction == target) {
575 return nameInst->mOperand2.c_str();
576 }
577 }
578 // Ignore member names
579 }
580 return nullptr;
581 }
582
// Builder-less constructor. mAnnotationsDeleter is bound to mAnnotations
// (presumably to free the annotation instructions with the section — TODO
// confirm in module.h).
AnnotationSection::AnnotationSection() : mAnnotationsDeleter(mAnnotations) {}

// Constructor that passes `b` through to Entity.
AnnotationSection::AnnotationSection(Builder *b)
    : Entity(b), mAnnotationsDeleter(mAnnotations) {}
587
DeserializeInternal(InputWordStream & IS)588 bool AnnotationSection::DeserializeInternal(InputWordStream &IS) {
589 while (true) {
590 if (auto decor = Deserialize<DecorateInst>(IS)) {
591 mAnnotations.push_back(decor);
592 } else if (auto decor = Deserialize<MemberDecorateInst>(IS)) {
593 mAnnotations.push_back(decor);
594 } else if (auto decor = Deserialize<GroupDecorateInst>(IS)) {
595 mAnnotations.push_back(decor);
596 } else if (auto decor = Deserialize<GroupMemberDecorateInst>(IS)) {
597 mAnnotations.push_back(decor);
598 } else if (auto decor = Deserialize<DecorationGroupInst>(IS)) {
599 mAnnotations.push_back(decor);
600 } else {
601 break;
602 }
603 }
604 return true;
605 }
606
// Builder-less constructor. NOTE(review): the factory getters below call
// mBuilder->Make*(), so a section built this way presumably cannot create new
// instructions. mGlobalDefsDeleter is bound to mGlobalDefs (presumably to free
// them with the section — TODO confirm in module.h).
GlobalSection::GlobalSection() : mGlobalDefsDeleter(mGlobalDefs) {}

// Constructor that passes `builder` through to Entity, which supplies mBuilder
// for the getters that create instructions.
GlobalSection::GlobalSection(Builder *builder)
    : Entity(builder), mGlobalDefsDeleter(mGlobalDefs) {}
611
612 namespace {
613
614 template <typename T>
findOrCreate(std::function<bool (T *)> criteria,std::function<T * ()> factory,std::vector<Instruction * > * globals)615 T *findOrCreate(std::function<bool(T *)> criteria, std::function<T *()> factory,
616 std::vector<Instruction *> *globals) {
617 T *derived;
618 for (auto inst : *globals) {
619 if (inst->getOpCode() == T::mOpCode) {
620 T *derived = static_cast<T *>(inst);
621 if (criteria(derived)) {
622 return derived;
623 }
624 }
625 }
626 derived = factory();
627 globals->push_back(derived);
628 return derived;
629 }
630
631 } // anonymous namespace
632
// Reads the global section: any interleaving of constant and type instructions
// (per the generated dispatch headers), OpVariable and OpUndef. Every match is
// appended to mGlobalDefs, and reading stops at the first instruction that is
// none of these. Always returns true (an empty section is accepted).
bool GlobalSection::DeserializeInternal(InputWordStream &IS) {
  while (true) {
// Each generated entry tries one instruction class. On success it records the
// instruction and restarts the scan from the top of the loop.
#define HANDLE_INSTRUCTION(OPCODE, INST_CLASS)                                 \
  if (auto typeInst = Deserialize<INST_CLASS>(IS)) {                           \
    mGlobalDefs.push_back(typeInst);                                           \
    continue;                                                                  \
  }
#include "const_inst_dispatches_generated.h"
#include "type_inst_dispatches_generated.h"
#undef HANDLE_INSTRUCTION

    if (auto globalInst = Deserialize<VariableInst>(IS)) {
      // Check if this is function scoped
      if (globalInst->mOperand1 == StorageClass::Function) {
        Module::errs() << "warning: Variable (id = " << globalInst->mResult;
        Module::errs() << ") has function scope in global section.\n";
        // Khronos LLVM-SPIRV convertor emits "Function" storage-class globals.
        // As a workaround, accept such SPIR-V code here, and fix it up later
        // in the rs2spirv compiler by correcting the storage class.
        // In a stricter deserializer, such code should be rejected, and we
        // should return false here.
      }
      mGlobalDefs.push_back(globalInst);
      continue;
    }

    if (auto globalInst = Deserialize<UndefInst>(IS)) {
      mGlobalDefs.push_back(globalInst);
      continue;
    }
    // Nothing matched: the global section ends here.
    break;
  }
  return true;
}
667
getConstant(TypeIntInst * type,int32_t value)668 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, int32_t value) {
669 return findOrCreate<ConstantInst>(
670 [=](ConstantInst *c) { return c->mOperand1.intValue == value; },
671 [=]() -> ConstantInst * {
672 LiteralContextDependentNumber cdn = {.intValue = value};
673 return mBuilder->MakeConstant(type, cdn);
674 },
675 &mGlobalDefs);
676 }
677
getConstant(TypeIntInst * type,uint32_t value)678 ConstantInst *GlobalSection::getConstant(TypeIntInst *type, uint32_t value) {
679 return findOrCreate<ConstantInst>(
680 [=](ConstantInst *c) { return c->mOperand1.intValue == (int)value; },
681 [=]() -> ConstantInst * {
682 LiteralContextDependentNumber cdn = {.intValue = (int)value};
683 return mBuilder->MakeConstant(type, cdn);
684 },
685 &mGlobalDefs);
686 }
687
getConstant(TypeFloatInst * type,float value)688 ConstantInst *GlobalSection::getConstant(TypeFloatInst *type, float value) {
689 return findOrCreate<ConstantInst>(
690 [=](ConstantInst *c) { return c->mOperand1.floatValue == value; },
691 [=]() -> ConstantInst * {
692 LiteralContextDependentNumber cdn = {.floatValue = value};
693 return mBuilder->MakeConstant(type, cdn);
694 },
695 &mGlobalDefs);
696 }
697
698 ConstantCompositeInst *
getConstantComposite(TypeVectorInst * type,ConstantInst * components[],size_t width)699 GlobalSection::getConstantComposite(TypeVectorInst *type,
700 ConstantInst *components[], size_t width) {
701 return findOrCreate<ConstantCompositeInst>(
702 [=](ConstantCompositeInst *c) {
703 if (c->mOperand1.size() != width) {
704 return false;
705 }
706 for (size_t i = 0; i < width; i++) {
707 if (c->mOperand1[i].mInstruction != components[i]) {
708 return false;
709 }
710 }
711 return true;
712 },
713 [=]() -> ConstantCompositeInst * {
714 ConstantCompositeInst *c = mBuilder->MakeConstantComposite(type);
715 for (size_t i = 0; i < width; i++) {
716 c->mOperand1.push_back(components[i]);
717 }
718 return c;
719 },
720 &mGlobalDefs);
721 }
722
getVoidType()723 TypeVoidInst *GlobalSection::getVoidType() {
724 return findOrCreate<TypeVoidInst>(
725 [=](TypeVoidInst *) -> bool { return true; },
726 [=]() -> TypeVoidInst * { return mBuilder->MakeTypeVoid(); },
727 &mGlobalDefs);
728 }
729
getIntType(int bits,bool isSigned)730 TypeIntInst *GlobalSection::getIntType(int bits, bool isSigned) {
731 if (isSigned) {
732 switch (bits) {
733 #define HANDLE_INT_SIZE(INT_TYPE, BITS, SIGNED) \
734 case BITS: { \
735 return findOrCreate<TypeIntInst>( \
736 [=](TypeIntInst *intTy) -> bool { \
737 return intTy->mOperand1 == BITS && intTy->mOperand2 == SIGNED; \
738 }, \
739 [=]() -> TypeIntInst * { \
740 return mBuilder->MakeTypeInt(BITS, SIGNED); \
741 }, \
742 &mGlobalDefs); \
743 }
744 HANDLE_INT_SIZE(Int, 8, 1);
745 HANDLE_INT_SIZE(Int, 16, 1);
746 HANDLE_INT_SIZE(Int, 32, 1);
747 HANDLE_INT_SIZE(Int, 64, 1);
748 default:
749 Module::errs() << "unexpected int type";
750 }
751 } else {
752 switch (bits) {
753 HANDLE_INT_SIZE(UInt, 8, 0);
754 HANDLE_INT_SIZE(UInt, 16, 0);
755 HANDLE_INT_SIZE(UInt, 32, 0);
756 HANDLE_INT_SIZE(UInt, 64, 0);
757 default:
758 Module::errs() << "unexpected int type";
759 }
760 }
761 #undef HANDLE_INT_SIZE
762 return nullptr;
763 }
764
getFloatType(int bits)765 TypeFloatInst *GlobalSection::getFloatType(int bits) {
766 switch (bits) {
767 #define HANDLE_FLOAT_SIZE(BITS) \
768 case BITS: { \
769 return findOrCreate<TypeFloatInst>( \
770 [=](TypeFloatInst *floatTy) -> bool { \
771 return floatTy->mOperand1 == BITS; \
772 }, \
773 [=]() -> TypeFloatInst * { return mBuilder->MakeTypeFloat(BITS); }, \
774 &mGlobalDefs); \
775 }
776 HANDLE_FLOAT_SIZE(16);
777 HANDLE_FLOAT_SIZE(32);
778 HANDLE_FLOAT_SIZE(64);
779 default:
780 Module::errs() << "unexpeced floating point type";
781 }
782 #undef HANDLE_FLOAT_SIZE
783 return nullptr;
784 }
785
getVectorType(Instruction * componentType,int width)786 TypeVectorInst *GlobalSection::getVectorType(Instruction *componentType,
787 int width) {
788 // TODO: verify that componentType is basic numeric types
789
790 return findOrCreate<TypeVectorInst>(
791 [=](TypeVectorInst *vecTy) -> bool {
792 return vecTy->mOperand1.mInstruction == componentType &&
793 vecTy->mOperand2 == width;
794 },
795 [=]() -> TypeVectorInst * {
796 return mBuilder->MakeTypeVector(componentType, width);
797 },
798 &mGlobalDefs);
799 }
800
getPointerType(StorageClass storage,Instruction * pointeeType)801 TypePointerInst *GlobalSection::getPointerType(StorageClass storage,
802 Instruction *pointeeType) {
803 return findOrCreate<TypePointerInst>(
804 [=](TypePointerInst *type) -> bool {
805 return type->mOperand1 == storage &&
806 type->mOperand2.mInstruction == pointeeType;
807 },
808 [=]() -> TypePointerInst * {
809 return mBuilder->MakeTypePointer(storage, pointeeType);
810 },
811 &mGlobalDefs);
812 }
813
814 TypeRuntimeArrayInst *
getRuntimeArrayType(Instruction * elemType)815 GlobalSection::getRuntimeArrayType(Instruction *elemType) {
816 return findOrCreate<TypeRuntimeArrayInst>(
817 [=](TypeRuntimeArrayInst * /*type*/) -> bool {
818 // return type->mOperand1.mInstruction == elemType;
819 return false;
820 },
821 [=]() -> TypeRuntimeArrayInst * {
822 return mBuilder->MakeTypeRuntimeArray(elemType);
823 },
824 &mGlobalDefs);
825 }
826
getStructType(Instruction * fieldType[],int numField)827 TypeStructInst *GlobalSection::getStructType(Instruction *fieldType[],
828 int numField) {
829 TypeStructInst *structTy = mBuilder->MakeTypeStruct();
830 for (int i = 0; i < numField; i++) {
831 structTy->mOperand1.push_back(fieldType[i]);
832 }
833 mGlobalDefs.push_back(structTy);
834 return structTy;
835 }
836
getFunctionType(Instruction * retType,Instruction * const argType[],size_t numArg)837 TypeFunctionInst *GlobalSection::getFunctionType(Instruction *retType,
838 Instruction *const argType[],
839 size_t numArg) {
840 return findOrCreate<TypeFunctionInst>(
841 [=](TypeFunctionInst *type) -> bool {
842 if (type->mOperand1.mInstruction != retType ||
843 type->mOperand2.size() != numArg) {
844 return false;
845 }
846 for (size_t i = 0; i < numArg; i++) {
847 if (type->mOperand2[i].mInstruction != argType[i]) {
848 return false;
849 }
850 }
851 return true;
852 },
853 [=]() -> TypeFunctionInst * {
854 TypeFunctionInst *funcTy = mBuilder->MakeTypeFunction(retType);
855 for (size_t i = 0; i < numArg; i++) {
856 funcTy->mOperand2.push_back(argType[i]);
857 }
858 return funcTy;
859 },
860 &mGlobalDefs);
861 }
862
addStructType(TypeStructInst * structType)863 GlobalSection *GlobalSection::addStructType(TypeStructInst *structType) {
864 mGlobalDefs.push_back(structType);
865 return this;
866 }
867
addVariable(VariableInst * var)868 GlobalSection *GlobalSection::addVariable(VariableInst *var) {
869 mGlobalDefs.push_back(var);
870 return this;
871 }
872
getInvocationId()873 VariableInst *GlobalSection::getInvocationId() {
874 if (mInvocationId) {
875 return mInvocationId.get();
876 }
877
878 TypeIntInst *UIntTy = getIntType(32, false);
879 TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
880 TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
881
882 VariableInst *InvocationId =
883 mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
884 InvocationId->decorate(Decoration::BuiltIn)
885 ->addExtraOperand(static_cast<uint32_t>(BuiltIn::GlobalInvocationId));
886
887 mInvocationId.reset(InvocationId);
888
889 return InvocationId;
890 }
891
getNumWorkgroups()892 VariableInst *GlobalSection::getNumWorkgroups() {
893 if (mNumWorkgroups) {
894 return mNumWorkgroups.get();
895 }
896
897 TypeIntInst *UIntTy = getIntType(32, false);
898 TypeVectorInst *V3UIntTy = getVectorType(UIntTy, 3);
899 TypePointerInst *V3UIntPtrTy = getPointerType(StorageClass::Input, V3UIntTy);
900
901 VariableInst *GNum = mBuilder->MakeVariable(V3UIntPtrTy, StorageClass::Input);
902 GNum->decorate(Decoration::BuiltIn)
903 ->addExtraOperand(static_cast<uint32_t>(BuiltIn::NumWorkgroups));
904
905 mNumWorkgroups.reset(GNum);
906
907 return GNum;
908 }
909
DeserializeInternal(InputWordStream & IS)910 bool FunctionDeclaration::DeserializeInternal(InputWordStream &IS) {
911 if (!Deserialize<FunctionInst>(IS)) {
912 return false;
913 }
914
915 DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
916
917 if (!Deserialize<FunctionEndInst>(IS)) {
918 return false;
919 }
920
921 return true;
922 }
923
Deserialize(InputWordStream & IS)924 template <> Instruction *Deserialize(InputWordStream &IS) {
925 Instruction *inst;
926
927 switch ((*IS) & 0xFFFF) {
928 #define HANDLE_INSTRUCTION(OPCODE, INST_CLASS) \
929 case OPCODE: \
930 inst = Deserialize<INST_CLASS>(IS); \
931 break;
932 #include "instruction_dispatches_generated.h"
933 #undef HANDLE_INSTRUCTION
934 default:
935 Module::errs() << "unrecognized instruction";
936 inst = nullptr;
937 }
938
939 return inst;
940 }
941
DeserializeInternal(InputWordStream & IS)942 bool Block::DeserializeInternal(InputWordStream &IS) {
943 Instruction *inst;
944 while (((*IS) & 0xFFFF) != OpFunctionEnd &&
945 (inst = Deserialize<Instruction>(IS))) {
946 mInsts.push_back(inst);
947 if (inst->getOpCode() == OpBranch ||
948 inst->getOpCode() == OpBranchConditional ||
949 inst->getOpCode() == OpSwitch || inst->getOpCode() == OpKill ||
950 inst->getOpCode() == OpReturn || inst->getOpCode() == OpReturnValue ||
951 inst->getOpCode() == OpUnreachable) {
952 break;
953 }
954 }
955 return !mInsts.empty();
956 }
957
// Empty definition, filled in by DeserializeInternal(). The *Deleter members
// are bound to mParams/mBlocks (presumably to free them with the definition —
// TODO confirm in module.h).
FunctionDefinition::FunctionDefinition()
    : mParamsDeleter(mParams), mBlocksDeleter(mBlocks) {}

// Definition delimited by the given OpFunction/OpFunctionEnd. Both are stored
// in mFunc/mFuncEnd. `builder` is passed through to Entity.
FunctionDefinition::FunctionDefinition(Builder *builder, FunctionInst *func,
                                       FunctionEndInst *end)
    : Entity(builder), mFunc(func), mFuncEnd(end), mParamsDeleter(mParams),
      mBlocksDeleter(mBlocks) {}
965
DeserializeInternal(InputWordStream & IS)966 bool FunctionDefinition::DeserializeInternal(InputWordStream &IS) {
967 mFunc.reset(Deserialize<FunctionInst>(IS));
968 if (!mFunc) {
969 return false;
970 }
971
972 DeserializeZeroOrMore<FunctionParameterInst>(IS, mParams);
973 DeserializeZeroOrMore<Block>(IS, mBlocks);
974
975 mFuncEnd.reset(Deserialize<FunctionEndInst>(IS));
976 if (!mFuncEnd) {
977 return false;
978 }
979
980 return true;
981 }
982
getReturnType() const983 Instruction *FunctionDefinition::getReturnType() const {
984 return mFunc->mResultType.mInstruction;
985 }
986
987 } // namespace spirit
988 } // namespace android
989