• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- GISelMITest.h --------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #ifndef LLVM_UNITTEST_CODEGEN_GLOBALISEL_GISELMI_H
9 #define LLVM_UNITTEST_CODEGEN_GLOBALISEL_GISELMI_H
10 
11 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h"
12 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
13 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
14 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h"
15 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
16 #include "llvm/CodeGen/GlobalISel/Utils.h"
17 #include "llvm/CodeGen/MIRParser/MIRParser.h"
18 #include "llvm/CodeGen/MachineFunction.h"
19 #include "llvm/CodeGen/MachineModuleInfo.h"
20 #include "llvm/CodeGen/TargetFrameLowering.h"
21 #include "llvm/CodeGen/TargetInstrInfo.h"
22 #include "llvm/CodeGen/TargetLowering.h"
23 #include "llvm/CodeGen/TargetSubtargetInfo.h"
24 #include "llvm/FileCheck/FileCheck.h"
25 #include "llvm/InitializePasses.h"
26 #include "llvm/Support/SourceMgr.h"
27 #include "llvm/Support/TargetRegistry.h"
28 #include "llvm/Support/TargetSelect.h"
29 #include "llvm/Target/TargetMachine.h"
30 #include "llvm/Target/TargetOptions.h"
31 #include "gtest/gtest.h"
32 
33 using namespace llvm;
34 using namespace MIPatternMatch;
35 
initLLVM()36 static inline void initLLVM() {
37   InitializeAllTargets();
38   InitializeAllTargetMCs();
39   InitializeAllAsmPrinters();
40   InitializeAllAsmParsers();
41 
42   PassRegistry *Registry = PassRegistry::getPassRegistry();
43   initializeCore(*Registry);
44   initializeCodeGen(*Registry);
45 }
46 
47 // Define a printers to help debugging when things go wrong.
48 namespace llvm {
49 std::ostream &
50 operator<<(std::ostream &OS, const LLT Ty);
51 
52 std::ostream &
53 operator<<(std::ostream &OS, const MachineFunction &MF);
54 }
55 
parseMIR(LLVMContext & Context,std::unique_ptr<MIRParser> & MIR,const TargetMachine & TM,StringRef MIRCode,const char * FuncName,MachineModuleInfo & MMI)56 static std::unique_ptr<Module> parseMIR(LLVMContext &Context,
57                                         std::unique_ptr<MIRParser> &MIR,
58                                         const TargetMachine &TM,
59                                         StringRef MIRCode, const char *FuncName,
60                                         MachineModuleInfo &MMI) {
61   SMDiagnostic Diagnostic;
62   std::unique_ptr<MemoryBuffer> MBuffer = MemoryBuffer::getMemBuffer(MIRCode);
63   MIR = createMIRParser(std::move(MBuffer), Context);
64   if (!MIR)
65     return nullptr;
66 
67   std::unique_ptr<Module> M = MIR->parseIRModule();
68   if (!M)
69     return nullptr;
70 
71   M->setDataLayout(TM.createDataLayout());
72 
73   if (MIR->parseMachineFunctions(*M, MMI))
74     return nullptr;
75 
76   return M;
77 }
78 static std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>>
createDummyModule(LLVMContext & Context,const LLVMTargetMachine & TM,StringRef MIRString,const char * FuncName)79 createDummyModule(LLVMContext &Context, const LLVMTargetMachine &TM,
80                   StringRef MIRString, const char *FuncName) {
81   std::unique_ptr<MIRParser> MIR;
82   auto MMI = std::make_unique<MachineModuleInfo>(&TM);
83   std::unique_ptr<Module> M =
84       parseMIR(Context, MIR, TM, MIRString, FuncName, *MMI);
85   return make_pair(std::move(M), std::move(MMI));
86 }
87 
getMFFromMMI(const Module * M,const MachineModuleInfo * MMI)88 static MachineFunction *getMFFromMMI(const Module *M,
89                                      const MachineModuleInfo *MMI) {
90   Function *F = M->getFunction("func");
91   auto *MF = MMI->getMachineFunction(*F);
92   return MF;
93 }
94 
collectCopies(SmallVectorImpl<Register> & Copies,MachineFunction * MF)95 static void collectCopies(SmallVectorImpl<Register> &Copies,
96                           MachineFunction *MF) {
97   for (auto &MBB : *MF)
98     for (MachineInstr &MI : MBB) {
99       if (MI.getOpcode() == TargetOpcode::COPY)
100         Copies.push_back(MI.getOperand(0).getReg());
101     }
102 }
103 
104 class GISelMITest : public ::testing::Test {
105 protected:
GISelMITest()106   GISelMITest() : ::testing::Test() {}
107 
108   /// Prepare a target specific LLVMTargetMachine.
109   virtual std::unique_ptr<LLVMTargetMachine> createTargetMachine() const = 0;
110 
111   /// Get the stub sample MIR test function.
112   virtual void getTargetTestModuleString(SmallString<512> &S,
113                                          StringRef MIRFunc) const = 0;
114 
115   void setUp(StringRef ExtraAssembly = "") {
116     TM = createTargetMachine();
117     if (!TM)
118       return;
119 
120     SmallString<512> MIRString;
121     getTargetTestModuleString(MIRString, ExtraAssembly);
122 
123     ModuleMMIPair = createDummyModule(Context, *TM, MIRString, "func");
124     MF = getMFFromMMI(ModuleMMIPair.first.get(), ModuleMMIPair.second.get());
125     collectCopies(Copies, MF);
126     EntryMBB = &*MF->begin();
127     B.setMF(*MF);
128     MRI = &MF->getRegInfo();
129     B.setInsertPt(*EntryMBB, EntryMBB->end());
130   }
131 
132   LLVMContext Context;
133   std::unique_ptr<LLVMTargetMachine> TM;
134   MachineFunction *MF;
135   std::pair<std::unique_ptr<Module>, std::unique_ptr<MachineModuleInfo>>
136       ModuleMMIPair;
137   SmallVector<Register, 4> Copies;
138   MachineBasicBlock *EntryMBB;
139   MachineIRBuilder B;
140   MachineRegisterInfo *MRI;
141 };
142 
143 class AArch64GISelMITest : public GISelMITest {
144   std::unique_ptr<LLVMTargetMachine> createTargetMachine() const override;
145   void getTargetTestModuleString(SmallString<512> &S,
146                                  StringRef MIRFunc) const override;
147 };
148 
149 class AMDGPUGISelMITest : public GISelMITest {
150   std::unique_ptr<LLVMTargetMachine> createTargetMachine() const override;
151   void getTargetTestModuleString(SmallString<512> &S,
152                                  StringRef MIRFunc) const override;
153 };
154 
// DefineLegalizerInfo(Name, { ...action setup... })
//
// Defines a LegalizerInfo subclass named Name##Info. Its constructor takes the
// subtarget and does the following:
// - brings TargetOpcode into scope;
// - declares the scalar types s8, s16, s32, s64 and s128 for the setup code
//   (the (void) casts silence unused-variable warnings when a block does not
//   use some of them);
// - runs the given setup block;
// - calls computeTables() and verify() against the subtarget's instruction
//   info.
//
// The setup block is taken as a variadic argument. The preprocessor only
// protects commas inside parentheses, so a block with a comma at brace level
// (e.g. a template argument list or a multi-variable declaration) would
// otherwise be split into extra macro arguments and fail to compile.
// __VA_ARGS__ re-joins those pieces verbatim.
#define DefineLegalizerInfo(Name, ...)                                         \
  class Name##Info : public LegalizerInfo {                                    \
  public:                                                                      \
    explicit Name##Info(const TargetSubtargetInfo &ST) {                       \
      using namespace TargetOpcode;                                            \
      const LLT s8 = LLT::scalar(8);                                           \
      (void)s8;                                                                \
      const LLT s16 = LLT::scalar(16);                                         \
      (void)s16;                                                               \
      const LLT s32 = LLT::scalar(32);                                         \
      (void)s32;                                                               \
      const LLT s64 = LLT::scalar(64);                                         \
      (void)s64;                                                               \
      const LLT s128 = LLT::scalar(128);                                       \
      (void)s128;                                                              \
      do                                                                       \
        __VA_ARGS__ while (0);                                                 \
      computeTables();                                                         \
      verify(*ST.getInstrInfo());                                              \
    }                                                                          \
  };
176 
CheckMachineFunction(const MachineFunction & MF,StringRef CheckStr)177 static inline bool CheckMachineFunction(const MachineFunction &MF,
178                                         StringRef CheckStr) {
179   SmallString<512> Msg;
180   raw_svector_ostream OS(Msg);
181   MF.print(OS);
182   auto OutputBuf = MemoryBuffer::getMemBuffer(Msg, "Output", false);
183   auto CheckBuf = MemoryBuffer::getMemBuffer(CheckStr, "");
184   SmallString<4096> CheckFileBuffer;
185   FileCheckRequest Req;
186   FileCheck FC(Req);
187   StringRef CheckFileText =
188       FC.CanonicalizeFile(*CheckBuf.get(), CheckFileBuffer);
189   SourceMgr SM;
190   SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(CheckFileText, "CheckFile"),
191                         SMLoc());
192   Regex PrefixRE = FC.buildCheckPrefixRegex();
193   if (FC.readCheckFile(SM, CheckFileText, PrefixRE))
194     return false;
195 
196   auto OutBuffer = OutputBuf->getBuffer();
197   SM.AddNewSourceBuffer(std::move(OutputBuf), SMLoc());
198   return FC.checkInput(SM, OutBuffer);
199 }
200 #endif
201