//===-- LowerTypeTests.cpp - type metadata lowering pass ------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass lowers type metadata and calls to the llvm.type.test intrinsic.
// See http://llvm.org/docs/TypeMetadata.html for more information.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/LowerTypeTests.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/Triple.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalObject.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;
using namespace lowertypetests;

#define DEBUG_TYPE "lowertypetests"

STATISTIC(ByteArraySizeBits, "Byte array size in bits");
STATISTIC(ByteArraySizeBytes, "Byte array size in bytes");
STATISTIC(NumByteArraysCreated, "Number of byte arrays created");
STATISTIC(NumTypeTestCallsLowered, "Number of type test calls lowered");
STATISTIC(NumTypeIdDisjointSets, "Number of disjoint sets of type identifiers");

static cl::opt<bool> AvoidReuse(
    "lowertypetests-avoid-reuse",
    cl::desc("Try to avoid reuse of byte array addresses using aliases"),
    cl::Hidden, cl::init(true));

bool BitSetInfo::containsGlobalOffset(uint64_t Offset) const {
  if (Offset < ByteOffset)
    return false;

  if ((Offset - ByteOffset) % (uint64_t(1) << AlignLog2) != 0)
    return false;

  uint64_t BitOffset = (Offset - ByteOffset) >> AlignLog2;
  if (BitOffset >= BitSize)
    return false;

  return Bits.count(BitOffset);
}

bool BitSetInfo::containsValue(
    const DataLayout &DL,
    const DenseMap<GlobalObject *, uint64_t> &GlobalLayout, Value *V,
    uint64_t COffset) const {
  if (auto GV = dyn_cast<GlobalObject>(V)) {
    auto I = GlobalLayout.find(GV);
    if (I == GlobalLayout.end())
      return false;
    return containsGlobalOffset(I->second + COffset);
  }

  if (auto GEP = dyn_cast<GEPOperator>(V)) {
    APInt APOffset(DL.getPointerSizeInBits(0), 0);
    bool Result = GEP->accumulateConstantOffset(DL, APOffset);
    if (!Result)
      return false;
    COffset += APOffset.getZExtValue();
    return containsValue(DL, GlobalLayout, GEP->getPointerOperand(),
                         COffset);
  }

  if (auto Op = dyn_cast<Operator>(V)) {
    if (Op->getOpcode() == Instruction::BitCast)
      return containsValue(DL, GlobalLayout, Op->getOperand(0), COffset);

    if (Op->getOpcode() == Instruction::Select)
      return containsValue(DL, GlobalLayout, Op->getOperand(1), COffset) &&
             containsValue(DL, GlobalLayout, Op->getOperand(2), COffset);
  }

  return false;
}

void BitSetInfo::print(raw_ostream &OS) const {
  OS << "offset " << ByteOffset << " size " << BitSize << " align "
     << (1 << AlignLog2);

  if (isAllOnes()) {
    OS << " all-ones\n";
    return;
  }

  OS << " { ";
  for (uint64_t B : Bits)
    OS << B << ' ';
  OS << "}\n";
}

BitSetInfo BitSetBuilder::build() {
  if (Min > Max)
    Min = 0;

  // Normalize each offset against the minimum observed offset, and compute
  // the bitwise OR of each of the offsets. The number of trailing zeros
  // in the mask gives us the log2 of the alignment of all offsets, which
  // allows us to compress the bitset by only storing one bit per aligned
  // address.
  uint64_t Mask = 0;
  for (uint64_t &Offset : Offsets) {
    Offset -= Min;
    Mask |= Offset;
  }

  BitSetInfo BSI;
  BSI.ByteOffset = Min;

  BSI.AlignLog2 = 0;
  if (Mask != 0)
    BSI.AlignLog2 = countTrailingZeros(Mask, ZB_Undefined);

  // Build the compressed bitset while normalizing the offsets against the
  // computed alignment.
  BSI.BitSize = ((Max - Min) >> BSI.AlignLog2) + 1;
  for (uint64_t Offset : Offsets) {
    Offset >>= BSI.AlignLog2;
    BSI.Bits.insert(Offset);
  }

  return BSI;
}

void GlobalLayoutBuilder::addFragment(const std::set<uint64_t> &F) {
  // Create a new fragment to hold the layout for F.
  Fragments.emplace_back();
  std::vector<uint64_t> &Fragment = Fragments.back();
  uint64_t FragmentIndex = Fragments.size() - 1;

  for (auto ObjIndex : F) {
    uint64_t OldFragmentIndex = FragmentMap[ObjIndex];
    if (OldFragmentIndex == 0) {
      // We haven't seen this object index before, so just add it to the current
      // fragment.
      Fragment.push_back(ObjIndex);
    } else {
      // This index belongs to an existing fragment. Copy the elements of the
      // old fragment into this one and clear the old fragment. We don't update
      // the fragment map just yet, this ensures that any further references to
      // indices from the old fragment in this fragment do not insert any more
      // indices.
      std::vector<uint64_t> &OldFragment = Fragments[OldFragmentIndex];
      Fragment.insert(Fragment.end(), OldFragment.begin(), OldFragment.end());
      OldFragment.clear();
    }
  }

  // Update the fragment map to point our object indices to this fragment.
  for (uint64_t ObjIndex : Fragment)
    FragmentMap[ObjIndex] = FragmentIndex;
}

void ByteArrayBuilder::allocate(const std::set<uint64_t> &Bits,
                                uint64_t BitSize, uint64_t &AllocByteOffset,
                                uint8_t &AllocMask) {
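  // Each byte of the shared byte array can hold one bit from up to eight
  // different bit sets, one per bit position. BitAllocs[I] records how many
  // bytes have been handed out so far at bit position I.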
  // Find the smallest current allocation.
  unsigned Bit = 0;
  for (unsigned I = 1; I != BitsPerByte; ++I)
    if (BitAllocs[I] < BitAllocs[Bit])
      Bit = I;

  AllocByteOffset = BitAllocs[Bit];

  // Add our size to it.
  unsigned ReqSize = AllocByteOffset + BitSize;
  BitAllocs[Bit] = ReqSize;
  if (Bytes.size() < ReqSize)
    Bytes.resize(ReqSize);

  // Set our bits.
  AllocMask = 1 << Bit;
  for (uint64_t B : Bits)
    Bytes[AllocByteOffset + B] |= AllocMask;
}

namespace {

struct ByteArrayInfo {
  std::set<uint64_t> Bits;
  uint64_t BitSize;
  GlobalVariable *ByteArray;
  Constant *Mask;
};

struct LowerTypeTests : public ModulePass {
  static char ID;
  LowerTypeTests() : ModulePass(ID) {
    initializeLowerTypeTestsPass(*PassRegistry::getPassRegistry());
  }

  Module *M;

  bool LinkerSubsectionsViaSymbols;
  Triple::ArchType Arch;
  Triple::ObjectFormatType ObjectFormat;
  IntegerType *Int1Ty;
  IntegerType *Int8Ty;
  IntegerType *Int32Ty;
  Type *Int32PtrTy;
  IntegerType *Int64Ty;
  IntegerType *IntPtrTy;

  // Mapping from type identifiers to the call sites that test them.
  DenseMap<Metadata *, std::vector<CallInst *>> TypeTestCallSites;

  std::vector<ByteArrayInfo> ByteArrayInfos;

  BitSetInfo
  buildBitSet(Metadata *TypeId,
              const DenseMap<GlobalObject *, uint64_t> &GlobalLayout);
  ByteArrayInfo *createByteArray(BitSetInfo &BSI);
  void allocateByteArrays();
  Value *createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI, ByteArrayInfo *&BAI,
                          Value *BitOffset);
  void
  lowerTypeTestCalls(ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
                     const DenseMap<GlobalObject *, uint64_t> &GlobalLayout);
  Value *
  lowerBitSetCall(CallInst *CI, BitSetInfo &BSI, ByteArrayInfo *&BAI,
                  Constant *CombinedGlobal,
                  const DenseMap<GlobalObject *, uint64_t> &GlobalLayout);
  void buildBitSetsFromGlobalVariables(ArrayRef<Metadata *> TypeIds,
                                       ArrayRef<GlobalVariable *> Globals);
  unsigned getJumpTableEntrySize();
  Type *getJumpTableEntryType();
  Constant *createJumpTableEntry(GlobalObject *Src, Function *Dest,
                                 unsigned Distance);
  void verifyTypeMDNode(GlobalObject *GO, MDNode *Type);
  void buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds,
                                 ArrayRef<Function *> Functions);
  void buildBitSetsFromDisjointSet(ArrayRef<Metadata *> TypeIds,
                                   ArrayRef<GlobalObject *> Globals);
  bool lower();
  bool runOnModule(Module &M) override;
};

} // anonymous namespace

INITIALIZE_PASS(LowerTypeTests, "lowertypetests", "Lower type metadata", false,
                false)
char LowerTypeTests::ID = 0;

ModulePass *llvm::createLowerTypeTestsPass() { return new LowerTypeTests; }

/// Build a bit set for TypeId using the object layouts in
/// GlobalLayout.
BitSetInfo LowerTypeTests::buildBitSet(
    Metadata *TypeId,
    const DenseMap<GlobalObject *, uint64_t> &GlobalLayout) {
  BitSetBuilder BSB;

  // Compute the byte offset of each address associated with this type
  // identifier.
  SmallVector<MDNode *, 2> Types;
  for (auto &GlobalAndOffset : GlobalLayout) {
    Types.clear();
    GlobalAndOffset.first->getMetadata(LLVMContext::MD_type, Types);
    for (MDNode *Type : Types) {
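      // Each !type metadata node is a pair { byte offset, type identifier };
      // only the entries naming this TypeId contribute offsets to the bit set.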
      if (Type->getOperand(1) != TypeId)
        continue;
      uint64_t Offset =
          cast<ConstantInt>(cast<ConstantAsMetadata>(Type->getOperand(0))
                                ->getValue())->getZExtValue();
      BSB.addOffset(GlobalAndOffset.second + Offset);
    }
  }

  return BSB.build();
}

/// Build a test that bit BitOffset mod sizeof(Bits)*8 is set in
/// Bits. This pattern matches to the bt instruction on x86.
static Value *createMaskedBitTest(IRBuilder<> &B, Value *Bits,
                                  Value *BitOffset) {
  auto BitsType = cast<IntegerType>(Bits->getType());
  unsigned BitWidth = BitsType->getBitWidth();

  BitOffset = B.CreateZExtOrTrunc(BitOffset, BitsType);
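  // BitWidth is a power of two, so masking with BitWidth - 1 reduces the
  // offset modulo BitWidth, giving the wrap-around behavior described above.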
  Value *BitIndex =
      B.CreateAnd(BitOffset, ConstantInt::get(BitsType, BitWidth - 1));
  Value *BitMask = B.CreateShl(ConstantInt::get(BitsType, 1), BitIndex);
  Value *MaskedBits = B.CreateAnd(Bits, BitMask);
  return B.CreateICmpNE(MaskedBits, ConstantInt::get(BitsType, 0));
}

ByteArrayInfo *LowerTypeTests::createByteArray(BitSetInfo &BSI) {
  // Create globals to stand in for byte arrays and masks. These never actually
  // get initialized, we RAUW and erase them later in allocateByteArrays() once
  // we know the offset and mask to use.
  auto ByteArrayGlobal = new GlobalVariable(
      *M, Int8Ty, /*isConstant=*/true, GlobalValue::PrivateLinkage, nullptr);
  auto MaskGlobal = new GlobalVariable(
      *M, Int8Ty, /*isConstant=*/true, GlobalValue::PrivateLinkage, nullptr);

  ByteArrayInfos.emplace_back();
  ByteArrayInfo *BAI = &ByteArrayInfos.back();

  BAI->Bits = BSI.Bits;
  BAI->BitSize = BSI.BitSize;
  BAI->ByteArray = ByteArrayGlobal;
  BAI->Mask = ConstantExpr::getPtrToInt(MaskGlobal, Int8Ty);
  return BAI;
}

void LowerTypeTests::allocateByteArrays() {
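  // Allocate the largest byte arrays first. This keeps the eight per-bit
  // allocation streams in the ByteArrayBuilder at similar lengths, which tends
  // to reduce the total number of bytes needed.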
  std::stable_sort(ByteArrayInfos.begin(), ByteArrayInfos.end(),
                   [](const ByteArrayInfo &BAI1, const ByteArrayInfo &BAI2) {
                     return BAI1.BitSize > BAI2.BitSize;
                   });

  std::vector<uint64_t> ByteArrayOffsets(ByteArrayInfos.size());

  ByteArrayBuilder BAB;
  for (unsigned I = 0; I != ByteArrayInfos.size(); ++I) {
    ByteArrayInfo *BAI = &ByteArrayInfos[I];

    uint8_t Mask;
    BAB.allocate(BAI->Bits, BAI->BitSize, ByteArrayOffsets[I], Mask);

    BAI->Mask->replaceAllUsesWith(ConstantInt::get(Int8Ty, Mask));
    cast<GlobalVariable>(BAI->Mask->getOperand(0))->eraseFromParent();
  }

  Constant *ByteArrayConst = ConstantDataArray::get(M->getContext(), BAB.Bytes);
  auto ByteArray =
      new GlobalVariable(*M, ByteArrayConst->getType(), /*isConstant=*/true,
                         GlobalValue::PrivateLinkage, ByteArrayConst);

  for (unsigned I = 0; I != ByteArrayInfos.size(); ++I) {
    ByteArrayInfo *BAI = &ByteArrayInfos[I];

    Constant *Idxs[] = {ConstantInt::get(IntPtrTy, 0),
                        ConstantInt::get(IntPtrTy, ByteArrayOffsets[I])};
    Constant *GEP = ConstantExpr::getInBoundsGetElementPtr(
        ByteArrayConst->getType(), ByteArray, Idxs);

    // Create an alias instead of RAUW'ing the gep directly. On x86 this ensures
    // that the pc-relative displacement is folded into the lea instead of the
    // test instruction getting another displacement.
    if (LinkerSubsectionsViaSymbols) {
      BAI->ByteArray->replaceAllUsesWith(GEP);
    } else {
      GlobalAlias *Alias = GlobalAlias::create(
          Int8Ty, 0, GlobalValue::PrivateLinkage, "bits", GEP, M);
      BAI->ByteArray->replaceAllUsesWith(Alias);
    }
    BAI->ByteArray->eraseFromParent();
  }

  ByteArraySizeBits = BAB.BitAllocs[0] + BAB.BitAllocs[1] + BAB.BitAllocs[2] +
                      BAB.BitAllocs[3] + BAB.BitAllocs[4] + BAB.BitAllocs[5] +
                      BAB.BitAllocs[6] + BAB.BitAllocs[7];
  ByteArraySizeBytes = BAB.Bytes.size();
}

/// Build a test that bit BitOffset is set in BSI, where
/// BitSetGlobal is a global containing the bits in BSI.
Value *LowerTypeTests::createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI,
                                        ByteArrayInfo *&BAI, Value *BitOffset) {
  if (BSI.BitSize <= 64) {
    // If the bit set is sufficiently small, we can avoid a load by bit testing
    // a constant.
    IntegerType *BitsTy;
    if (BSI.BitSize <= 32)
      BitsTy = Int32Ty;
    else
      BitsTy = Int64Ty;

    uint64_t Bits = 0;
    for (auto Bit : BSI.Bits)
      Bits |= uint64_t(1) << Bit;
    Constant *BitsConst = ConstantInt::get(BitsTy, Bits);
    return createMaskedBitTest(B, BitsConst, BitOffset);
  } else {
    if (!BAI) {
      ++NumByteArraysCreated;
      BAI = createByteArray(BSI);
    }

    Constant *ByteArray = BAI->ByteArray;
    Type *Ty = BAI->ByteArray->getValueType();
    if (!LinkerSubsectionsViaSymbols && AvoidReuse) {
      // Each use of the byte array uses a different alias. This makes the
      // backend less likely to reuse previously computed byte array addresses,
      // improving the security of the CFI mechanism based on this pass.
      ByteArray = GlobalAlias::create(BAI->ByteArray->getValueType(), 0,
                                      GlobalValue::PrivateLinkage, "bits_use",
                                      ByteArray, M);
    }

    Value *ByteAddr = B.CreateGEP(Ty, ByteArray, BitOffset);
    Value *Byte = B.CreateLoad(ByteAddr);

    Value *ByteAndMask = B.CreateAnd(Byte, BAI->Mask);
    return B.CreateICmpNE(ByteAndMask, ConstantInt::get(Int8Ty, 0));
  }
}

/// Lower a llvm.type.test call to its implementation. Returns the value to
/// replace the call with.
Value *LowerTypeTests::lowerBitSetCall(
    CallInst *CI, BitSetInfo &BSI, ByteArrayInfo *&BAI,
    Constant *CombinedGlobalIntAddr,
    const DenseMap<GlobalObject *, uint64_t> &GlobalLayout) {
  Value *Ptr = CI->getArgOperand(0);
  const DataLayout &DL = M->getDataLayout();

  if (BSI.containsValue(DL, GlobalLayout, Ptr))
    return ConstantInt::getTrue(M->getContext());

  Constant *OffsetedGlobalAsInt = ConstantExpr::getAdd(
      CombinedGlobalIntAddr, ConstantInt::get(IntPtrTy, BSI.ByteOffset));

  BasicBlock *InitialBB = CI->getParent();

  IRBuilder<> B(CI);

  Value *PtrAsInt = B.CreatePtrToInt(Ptr, IntPtrTy);

  if (BSI.isSingleOffset())
    return B.CreateICmpEQ(PtrAsInt, OffsetedGlobalAsInt);

  Value *PtrOffset = B.CreateSub(PtrAsInt, OffsetedGlobalAsInt);

  Value *BitOffset;
  if (BSI.AlignLog2 == 0) {
    BitOffset = PtrOffset;
  } else {
    // We need to check that the offset both falls within our range and is
    // suitably aligned. We can check both properties at the same time by
    // performing a right rotate by log2(alignment) followed by an integer
    // comparison against the bitset size. The rotate will move the lower
    // order bits that need to be zero into the higher order bits of the
    // result, causing the comparison to fail if they are nonzero. The rotate
    // also conveniently gives us a bit offset to use during the load from
    // the bitset.
    Value *OffsetSHR =
        B.CreateLShr(PtrOffset, ConstantInt::get(IntPtrTy, BSI.AlignLog2));
    Value *OffsetSHL = B.CreateShl(
        PtrOffset,
        ConstantInt::get(IntPtrTy, DL.getPointerSizeInBits(0) - BSI.AlignLog2));
    BitOffset = B.CreateOr(OffsetSHR, OffsetSHL);
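    // The lshr/shl/or sequence above implements the right rotate described in
    // the comment; it is written with two shifts because IR has no dedicated
    // rotate operation.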
  }

  Constant *BitSizeConst = ConstantInt::get(IntPtrTy, BSI.BitSize);
  Value *OffsetInRange = B.CreateICmpULT(BitOffset, BitSizeConst);

  // If the bit set is all ones, testing against it is unnecessary.
  if (BSI.isAllOnes())
    return OffsetInRange;

  TerminatorInst *Term = SplitBlockAndInsertIfThen(OffsetInRange, CI, false);
  IRBuilder<> ThenB(Term);

  // Now that we know that the offset is in range and aligned, load the
  // appropriate bit from the bitset.
  Value *Bit = createBitSetTest(ThenB, BSI, BAI, BitOffset);

  // The value we want is 0 if we came directly from the initial block
  // (having failed the range or alignment checks), or the loaded bit if
  // we came from the block in which we loaded it.
  B.SetInsertPoint(CI);
  PHINode *P = B.CreatePHI(Int1Ty, 2);
  P->addIncoming(ConstantInt::get(Int1Ty, 0), InitialBB);
  P->addIncoming(Bit, ThenB.GetInsertBlock());
  return P;
}

/// Given a disjoint set of type identifiers and globals, lay out the globals,
/// build the bit sets and lower the llvm.type.test calls.
void LowerTypeTests::buildBitSetsFromGlobalVariables(
    ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalVariable *> Globals) {
  // Build a new global with the combined contents of the referenced globals.
  // This global is a struct whose even-indexed elements contain the original
  // contents of the referenced globals and whose odd-indexed elements contain
  // any padding required to align the next element to the next power of 2.
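  // Padding each member to a power of 2 keeps the members' offsets highly
  // aligned, which lets BitSetBuilder::build() compress the resulting bit sets
  // using their common alignment (AlignLog2).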
  std::vector<Constant *> GlobalInits;
  const DataLayout &DL = M->getDataLayout();
  for (GlobalVariable *G : Globals) {
    GlobalInits.push_back(G->getInitializer());
    uint64_t InitSize = DL.getTypeAllocSize(G->getValueType());

    // Compute the amount of padding required.
    uint64_t Padding = NextPowerOf2(InitSize - 1) - InitSize;

    // Cap at 128 was found experimentally to have a good data/instruction
    // overhead tradeoff.
    if (Padding > 128)
      Padding = alignTo(InitSize, 128) - InitSize;

    GlobalInits.push_back(
        ConstantAggregateZero::get(ArrayType::get(Int8Ty, Padding)));
  }
  if (!GlobalInits.empty())
    GlobalInits.pop_back();
  Constant *NewInit = ConstantStruct::getAnon(M->getContext(), GlobalInits);
  auto *CombinedGlobal =
      new GlobalVariable(*M, NewInit->getType(), /*isConstant=*/true,
                         GlobalValue::PrivateLinkage, NewInit);

  StructType *NewTy = cast<StructType>(NewInit->getType());
  const StructLayout *CombinedGlobalLayout = DL.getStructLayout(NewTy);

  // Compute the offsets of the original globals within the new global.
  DenseMap<GlobalObject *, uint64_t> GlobalLayout;
  for (unsigned I = 0; I != Globals.size(); ++I)
    // Multiply by 2 to account for padding elements.
    GlobalLayout[Globals[I]] = CombinedGlobalLayout->getElementOffset(I * 2);

  lowerTypeTestCalls(TypeIds, CombinedGlobal, GlobalLayout);

  // Build aliases pointing to offsets into the combined global for each
  // global from which we built the combined global, and replace references
  // to the original globals with references to the aliases.
  for (unsigned I = 0; I != Globals.size(); ++I) {
    // Multiply by 2 to account for padding elements.
    Constant *CombinedGlobalIdxs[] = {ConstantInt::get(Int32Ty, 0),
                                      ConstantInt::get(Int32Ty, I * 2)};
    Constant *CombinedGlobalElemPtr = ConstantExpr::getGetElementPtr(
        NewInit->getType(), CombinedGlobal, CombinedGlobalIdxs);
    if (LinkerSubsectionsViaSymbols) {
      Globals[I]->replaceAllUsesWith(CombinedGlobalElemPtr);
    } else {
      assert(Globals[I]->getType()->getAddressSpace() == 0);
      GlobalAlias *GAlias = GlobalAlias::create(NewTy->getElementType(I * 2), 0,
                                                Globals[I]->getLinkage(), "",
                                                CombinedGlobalElemPtr, M);
      GAlias->setVisibility(Globals[I]->getVisibility());
      GAlias->takeName(Globals[I]);
      Globals[I]->replaceAllUsesWith(GAlias);
    }
    Globals[I]->eraseFromParent();
  }
}

void LowerTypeTests::lowerTypeTestCalls(
    ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
    const DenseMap<GlobalObject *, uint64_t> &GlobalLayout) {
  Constant *CombinedGlobalIntAddr =
      ConstantExpr::getPtrToInt(CombinedGlobalAddr, IntPtrTy);

  // For each type identifier in this disjoint set...
  for (Metadata *TypeId : TypeIds) {
    // Build the bitset.
    BitSetInfo BSI = buildBitSet(TypeId, GlobalLayout);
    DEBUG({
      if (auto MDS = dyn_cast<MDString>(TypeId))
        dbgs() << MDS->getString() << ": ";
      else
        dbgs() << "<unnamed>: ";
      BSI.print(dbgs());
    });

    ByteArrayInfo *BAI = nullptr;

    // Lower each call to llvm.type.test for this type identifier.
    for (CallInst *CI : TypeTestCallSites[TypeId]) {
      ++NumTypeTestCallsLowered;
      Value *Lowered =
          lowerBitSetCall(CI, BSI, BAI, CombinedGlobalIntAddr, GlobalLayout);
      CI->replaceAllUsesWith(Lowered);
      CI->eraseFromParent();
    }
  }
}

void LowerTypeTests::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) {
  if (Type->getNumOperands() != 2)
    report_fatal_error(
        "All operands of type metadata must have 2 elements");

  if (GO->isThreadLocal())
    report_fatal_error("Bit set element may not be thread-local");
  if (isa<GlobalVariable>(GO) && GO->hasSection())
    report_fatal_error(
        "A member of a type identifier may not have an explicit section");

  if (isa<GlobalVariable>(GO) && GO->isDeclarationForLinker())
    report_fatal_error(
        "A global var member of a type identifier must be a definition");

  auto OffsetConstMD = dyn_cast<ConstantAsMetadata>(Type->getOperand(0));
  if (!OffsetConstMD)
    report_fatal_error("Type offset must be a constant");
  auto OffsetInt = dyn_cast<ConstantInt>(OffsetConstMD->getValue());
  if (!OffsetInt)
    report_fatal_error("Type offset must be an integer constant");
}

static const unsigned kX86JumpTableEntrySize = 8;

unsigned LowerTypeTests::getJumpTableEntrySize() {
  if (Arch != Triple::x86 && Arch != Triple::x86_64)
    report_fatal_error("Unsupported architecture for jump tables");

  return kX86JumpTableEntrySize;
}

// Create a constant representing a jump table entry for the target. This
// consists of an instruction sequence containing a relative branch to Dest. The
// constant will be laid out at address Src+(Len*Distance) where Len is the
// target-specific jump table entry size.
Constant *LowerTypeTests::createJumpTableEntry(GlobalObject *Src,
                                               Function *Dest,
                                               unsigned Distance) {
  if (Arch != Triple::x86 && Arch != Triple::x86_64)
    report_fatal_error("Unsupported architecture for jump tables");

  const unsigned kJmpPCRel32Code = 0xe9;
  const unsigned kInt3Code = 0xcc;

  ConstantInt *Jmp = ConstantInt::get(Int8Ty, kJmpPCRel32Code);

  // Build a constant representing the displacement between the constant's
  // address and Dest. This will resolve to a PC32 relocation referring to Dest.
  Constant *DestInt = ConstantExpr::getPtrToInt(Dest, IntPtrTy);
  Constant *SrcInt = ConstantExpr::getPtrToInt(Src, IntPtrTy);
  Constant *Disp = ConstantExpr::getSub(DestInt, SrcInt);
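  // The jmp's rel32 operand is relative to the end of the 5-byte jmp
  // instruction, so adjust the displacement by this entry's offset within the
  // table (Distance * entry size) plus the size of the jmp itself.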
  ConstantInt *DispOffset =
      ConstantInt::get(IntPtrTy, Distance * kX86JumpTableEntrySize + 5);
  Constant *OffsetedDisp = ConstantExpr::getSub(Disp, DispOffset);
  OffsetedDisp = ConstantExpr::getTruncOrBitCast(OffsetedDisp, Int32Ty);

  ConstantInt *Int3 = ConstantInt::get(Int8Ty, kInt3Code);

  Constant *Fields[] = {
      Jmp, OffsetedDisp, Int3, Int3, Int3,
  };
  return ConstantStruct::getAnon(Fields, /*Packed=*/true);
}

Type *LowerTypeTests::getJumpTableEntryType() {
  if (Arch != Triple::x86 && Arch != Triple::x86_64)
    report_fatal_error("Unsupported architecture for jump tables");

  return StructType::get(M->getContext(),
                         {Int8Ty, Int32Ty, Int8Ty, Int8Ty, Int8Ty},
                         /*Packed=*/true);
}

/// Given a disjoint set of type identifiers and functions, build a jump table
/// for the functions, build the bit sets and lower the llvm.type.test calls.
void LowerTypeTests::buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds,
                                               ArrayRef<Function *> Functions) {
  // Unlike the global bitset builder, the function bitset builder cannot
  // re-arrange functions in a particular order and base its calculations on the
  // layout of the functions' entry points, as we have no idea how large a
  // particular function will end up being (the size could even depend on what
  // this pass does!) Instead, we build a jump table, which is a block of code
  // consisting of one branch instruction for each of the functions in the bit
  // set that branches to the target function, and redirect any taken function
  // addresses to the corresponding jump table entry. In the object file's
  // symbol table, the symbols for the target functions also refer to the jump
  // table entries, so that addresses taken outside the module will pass any
  // verification done inside the module.
  //
  // In more concrete terms, suppose we have three functions f, g, h which are
  // of the same type, and a function foo that returns their addresses:
  //
  // f:
  // mov 0, %eax
  // ret
  //
  // g:
  // mov 1, %eax
  // ret
  //
  // h:
  // mov 2, %eax
  // ret
  //
  // foo:
  // mov f, %eax
  // mov g, %edx
  // mov h, %ecx
  // ret
  //
  // To create a jump table for these functions, we instruct the LLVM code
  // generator to output a jump table in the .text section. This is done by
  // representing the instructions in the jump table as an LLVM constant and
  // placing them in a global variable in the .text section. The end result will
  // (conceptually) look like this:
  //
  // f:
  // jmp .Ltmp0 ; 5 bytes
  // int3       ; 1 byte
  // int3       ; 1 byte
  // int3       ; 1 byte
  //
  // g:
  // jmp .Ltmp1 ; 5 bytes
  // int3       ; 1 byte
  // int3       ; 1 byte
  // int3       ; 1 byte
  //
  // h:
  // jmp .Ltmp2 ; 5 bytes
  // int3       ; 1 byte
  // int3       ; 1 byte
  // int3       ; 1 byte
  //
  // .Ltmp0:
  // mov 0, %eax
  // ret
  //
  // .Ltmp1:
  // mov 1, %eax
  // ret
  //
  // .Ltmp2:
  // mov 2, %eax
  // ret
  //
  // foo:
  // mov f, %eax
  // mov g, %edx
  // mov h, %ecx
  // ret
  //
  // Because the addresses of f, g, h are evenly spaced at a power of 2, in the
  // normal case the check can be carried out using the same kind of simple
  // arithmetic that we normally use for globals.

  assert(!Functions.empty());

  // Build a simple layout based on the regular layout of jump tables.
  DenseMap<GlobalObject *, uint64_t> GlobalLayout;
  unsigned EntrySize = getJumpTableEntrySize();
  for (unsigned I = 0; I != Functions.size(); ++I)
    GlobalLayout[Functions[I]] = I * EntrySize;

  // Create a constant to hold the jump table.
  ArrayType *JumpTableType =
      ArrayType::get(getJumpTableEntryType(), Functions.size());
  auto JumpTable = new GlobalVariable(*M, JumpTableType,
                                      /*isConstant=*/true,
                                      GlobalValue::PrivateLinkage, nullptr);
  JumpTable->setSection(ObjectFormat == Triple::MachO
                            ? "__TEXT,__text,regular,pure_instructions"
                            : ".text");
  lowerTypeTestCalls(TypeIds, JumpTable, GlobalLayout);

  // Build aliases pointing to offsets into the jump table, and replace
  // references to the original functions with references to the aliases.
  for (unsigned I = 0; I != Functions.size(); ++I) {
    Constant *CombinedGlobalElemPtr = ConstantExpr::getBitCast(
        ConstantExpr::getGetElementPtr(
            JumpTableType, JumpTable,
            ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0),
                                 ConstantInt::get(IntPtrTy, I)}),
        Functions[I]->getType());
    if (LinkerSubsectionsViaSymbols || Functions[I]->isDeclarationForLinker()) {
      Functions[I]->replaceAllUsesWith(CombinedGlobalElemPtr);
    } else {
      assert(Functions[I]->getType()->getAddressSpace() == 0);
      GlobalAlias *GAlias = GlobalAlias::create(Functions[I]->getValueType(), 0,
                                                Functions[I]->getLinkage(), "",
                                                CombinedGlobalElemPtr, M);
      GAlias->setVisibility(Functions[I]->getVisibility());
      GAlias->takeName(Functions[I]);
      Functions[I]->replaceAllUsesWith(GAlias);
    }
    if (!Functions[I]->isDeclarationForLinker())
      Functions[I]->setLinkage(GlobalValue::PrivateLinkage);
  }

  // Build and set the jump table's initializer.
  std::vector<Constant *> JumpTableEntries;
  for (unsigned I = 0; I != Functions.size(); ++I)
    JumpTableEntries.push_back(
        createJumpTableEntry(JumpTable, Functions[I], I));
  JumpTable->setInitializer(
      ConstantArray::get(JumpTableType, JumpTableEntries));
}

void LowerTypeTests::buildBitSetsFromDisjointSet(
    ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalObject *> Globals) {
  llvm::DenseMap<Metadata *, uint64_t> TypeIdIndices;
  for (unsigned I = 0; I != TypeIds.size(); ++I)
    TypeIdIndices[TypeIds[I]] = I;

  // For each type identifier, build a set of indices that refer to members of
  // the type identifier.
  std::vector<std::set<uint64_t>> TypeMembers(TypeIds.size());
  SmallVector<MDNode *, 2> Types;
  unsigned GlobalIndex = 0;
  for (GlobalObject *GO : Globals) {
    Types.clear();
    GO->getMetadata(LLVMContext::MD_type, Types);
    for (MDNode *Type : Types) {
      // Type = { offset, type identifier }
      unsigned TypeIdIndex = TypeIdIndices[Type->getOperand(1)];
      TypeMembers[TypeIdIndex].insert(GlobalIndex);
    }
    GlobalIndex++;
  }

  // Order the sets of indices by size. The GlobalLayoutBuilder works best
  // when given small index sets first.
  std::stable_sort(
      TypeMembers.begin(), TypeMembers.end(),
      [](const std::set<uint64_t> &O1, const std::set<uint64_t> &O2) {
        return O1.size() < O2.size();
      });

  // Create a GlobalLayoutBuilder and provide it with index sets as layout
  // fragments. The GlobalLayoutBuilder tries to lay out members of fragments as
  // close together as possible.
  GlobalLayoutBuilder GLB(Globals.size());
  for (auto &&MemSet : TypeMembers)
    GLB.addFragment(MemSet);

  // Build the bitsets from this disjoint set.
  if (Globals.empty() || isa<GlobalVariable>(Globals[0])) {
    // Build a vector of global variables with the computed layout.
    std::vector<GlobalVariable *> OrderedGVs(Globals.size());
    auto OGI = OrderedGVs.begin();
    for (auto &&F : GLB.Fragments) {
      for (auto &&Offset : F) {
        auto GV = dyn_cast<GlobalVariable>(Globals[Offset]);
        if (!GV)
          report_fatal_error("Type identifier may not contain both global "
                             "variables and functions");
        *OGI++ = GV;
      }
    }

    buildBitSetsFromGlobalVariables(TypeIds, OrderedGVs);
  } else {
    // Build a vector of functions with the computed layout.
    std::vector<Function *> OrderedFns(Globals.size());
    auto OFI = OrderedFns.begin();
    for (auto &&F : GLB.Fragments) {
      for (auto &&Offset : F) {
        auto Fn = dyn_cast<Function>(Globals[Offset]);
        if (!Fn)
          report_fatal_error("Type identifier may not contain both global "
                             "variables and functions");
        *OFI++ = Fn;
      }
    }

    buildBitSetsFromFunctions(TypeIds, OrderedFns);
  }
}

/// Lower all type tests in this module.
bool LowerTypeTests::lower() {
  Function *TypeTestFunc =
      M->getFunction(Intrinsic::getName(Intrinsic::type_test));
  if (!TypeTestFunc || TypeTestFunc->use_empty())
    return false;

  // Equivalence class set containing type identifiers and the globals that
  // reference them. This is used to partition the set of type identifiers in
  // the module into disjoint sets.
  typedef EquivalenceClasses<PointerUnion<GlobalObject *, Metadata *>>
      GlobalClassesTy;
  GlobalClassesTy GlobalClasses;

  // Verify the type metadata and build a mapping from type identifiers to their
  // last observed index in the list of globals. This will be used later to
  // deterministically order the list of type identifiers.
  llvm::DenseMap<Metadata *, unsigned> TypeIdIndices;
  unsigned I = 0;
  SmallVector<MDNode *, 2> Types;
  for (GlobalObject &GO : M->global_objects()) {
    Types.clear();
    GO.getMetadata(LLVMContext::MD_type, Types);
    for (MDNode *Type : Types) {
      verifyTypeMDNode(&GO, Type);
      TypeIdIndices[cast<MDNode>(Type)->getOperand(1)] = ++I;
    }
  }

  for (const Use &U : TypeTestFunc->uses()) {
    auto CI = cast<CallInst>(U.getUser());

    auto BitSetMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1));
    if (!BitSetMDVal)
      report_fatal_error(
          "Second argument of llvm.type.test must be metadata");
    auto BitSet = BitSetMDVal->getMetadata();

    // Add the call site to the list of call sites for this type identifier. We
    // also use TypeTestCallSites to keep track of whether we have seen this
    // type identifier before. If we have, we don't need to re-add the
    // referenced globals to the equivalence class.
    std::pair<DenseMap<Metadata *, std::vector<CallInst *>>::iterator, bool>
        Ins = TypeTestCallSites.insert(
            std::make_pair(BitSet, std::vector<CallInst *>()));
    Ins.first->second.push_back(CI);
    if (!Ins.second)
      continue;

    // Add the type identifier to the equivalence class.
    GlobalClassesTy::iterator GCI = GlobalClasses.insert(BitSet);
    GlobalClassesTy::member_iterator CurSet = GlobalClasses.findLeader(GCI);

    // Add the referenced globals to the type identifier's equivalence class.
    for (GlobalObject &GO : M->global_objects()) {
      Types.clear();
      GO.getMetadata(LLVMContext::MD_type, Types);
      for (MDNode *Type : Types)
        if (Type->getOperand(1) == BitSet)
          CurSet = GlobalClasses.unionSets(
              CurSet, GlobalClasses.findLeader(GlobalClasses.insert(&GO)));
    }
  }

  if (GlobalClasses.empty())
    return false;

  // Build a list of disjoint sets ordered by their maximum global index for
  // determinism.
  std::vector<std::pair<GlobalClassesTy::iterator, unsigned>> Sets;
  for (GlobalClassesTy::iterator I = GlobalClasses.begin(),
                                 E = GlobalClasses.end();
       I != E; ++I) {
    if (!I->isLeader()) continue;
    ++NumTypeIdDisjointSets;

    unsigned MaxIndex = 0;
    for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(I);
         MI != GlobalClasses.member_end(); ++MI) {
      if ((*MI).is<Metadata *>())
        MaxIndex = std::max(MaxIndex, TypeIdIndices[MI->get<Metadata *>()]);
    }
    Sets.emplace_back(I, MaxIndex);
  }
  std::sort(Sets.begin(), Sets.end(),
            [](const std::pair<GlobalClassesTy::iterator, unsigned> &S1,
               const std::pair<GlobalClassesTy::iterator, unsigned> &S2) {
              return S1.second < S2.second;
            });

  // For each disjoint set we found...
  for (const auto &S : Sets) {
    // Build the list of type identifiers in this disjoint set.
    std::vector<Metadata *> TypeIds;
    std::vector<GlobalObject *> Globals;
    for (GlobalClassesTy::member_iterator MI =
             GlobalClasses.member_begin(S.first);
         MI != GlobalClasses.member_end(); ++MI) {
      if ((*MI).is<Metadata *>())
        TypeIds.push_back(MI->get<Metadata *>());
      else
        Globals.push_back(MI->get<GlobalObject *>());
    }

    // Order type identifiers by global index for determinism. This ordering is
    // stable as there is a one-to-one mapping between metadata and indices.
    std::sort(TypeIds.begin(), TypeIds.end(), [&](Metadata *M1, Metadata *M2) {
      return TypeIdIndices[M1] < TypeIdIndices[M2];
    });

    // Build bitsets for this disjoint set.
    buildBitSetsFromDisjointSet(TypeIds, Globals);
  }

  allocateByteArrays();

  return true;
}

// Initialization helper shared by the old and the new PM.
static void init(LowerTypeTests *LTT, Module &M) {
  LTT->M = &M;
  const DataLayout &DL = M.getDataLayout();
  Triple TargetTriple(M.getTargetTriple());
  LTT->LinkerSubsectionsViaSymbols = TargetTriple.isMacOSX();
  LTT->Arch = TargetTriple.getArch();
  LTT->ObjectFormat = TargetTriple.getObjectFormat();
  LTT->Int1Ty = Type::getInt1Ty(M.getContext());
  LTT->Int8Ty = Type::getInt8Ty(M.getContext());
  LTT->Int32Ty = Type::getInt32Ty(M.getContext());
  LTT->Int32PtrTy = PointerType::getUnqual(LTT->Int32Ty);
  LTT->Int64Ty = Type::getInt64Ty(M.getContext());
  LTT->IntPtrTy = DL.getIntPtrType(M.getContext(), 0);
  LTT->TypeTestCallSites.clear();
}

bool LowerTypeTests::runOnModule(Module &M) {
  if (skipModule(M))
    return false;
  init(this, M);
  return lower();
}

PreservedAnalyses LowerTypeTestsPass::run(Module &M,
                                          AnalysisManager<Module> &AM) {
  LowerTypeTests Impl;
  init(&Impl, M);
  bool Changed = Impl.lower();
  if (!Changed)
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}