• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include "llvm/ADT/APFloat.h"
2 #include "llvm/ADT/STLExtras.h"
3 #include "llvm/IR/BasicBlock.h"
4 #include "llvm/IR/Constants.h"
5 #include "llvm/IR/DerivedTypes.h"
6 #include "llvm/IR/Function.h"
7 #include "llvm/IR/IRBuilder.h"
8 #include "llvm/IR/LLVMContext.h"
9 #include "llvm/IR/Module.h"
10 #include "llvm/IR/Type.h"
11 #include "llvm/IR/Verifier.h"
12 #include <cctype>
13 #include <cstdio>
14 #include <cstdlib>
15 #include <map>
16 #include <memory>
17 #include <string>
18 #include <vector>
19 
20 using namespace llvm;
21 
22 //===----------------------------------------------------------------------===//
23 // Lexer
24 //===----------------------------------------------------------------------===//
25 
// Token codes produced by the lexer. Characters the lexer does not recognise
// are returned as their own value in [0-255]; every recognised token kind
// gets one of the negative codes below so the two ranges cannot collide.
enum Token {
  tok_eof = -1,        // end of input

  tok_def = -2,        // the "def" keyword
  tok_extern = -3,     // the "extern" keyword

  tok_identifier = -4, // identifier; its text is stored in IdentifierStr
  tok_number = -5      // numeric literal; its value is stored in NumVal
};
39 
// Text of the most recent identifier token (filled in if tok_identifier).
static std::string IdentifierStr;
// Value of the most recent numeric literal (filled in if tok_number).
static double NumVal = 0.0;
42 
43 /// gettok - Return the next token from standard input.
gettok()44 static int gettok() {
45   static int LastChar = ' ';
46 
47   // Skip any whitespace.
48   while (isspace(LastChar))
49     LastChar = getchar();
50 
51   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
52     IdentifierStr = LastChar;
53     while (isalnum((LastChar = getchar())))
54       IdentifierStr += LastChar;
55 
56     if (IdentifierStr == "def")
57       return tok_def;
58     if (IdentifierStr == "extern")
59       return tok_extern;
60     return tok_identifier;
61   }
62 
63   if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
64     std::string NumStr;
65     do {
66       NumStr += LastChar;
67       LastChar = getchar();
68     } while (isdigit(LastChar) || LastChar == '.');
69 
70     NumVal = strtod(NumStr.c_str(), nullptr);
71     return tok_number;
72   }
73 
74   if (LastChar == '#') {
75     // Comment until end of line.
76     do
77       LastChar = getchar();
78     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
79 
80     if (LastChar != EOF)
81       return gettok();
82   }
83 
84   // Check for end of file.  Don't eat the EOF.
85   if (LastChar == EOF)
86     return tok_eof;
87 
88   // Otherwise, just return the character as its ascii value.
89   int ThisChar = LastChar;
90   LastChar = getchar();
91   return ThisChar;
92 }
93 
94 //===----------------------------------------------------------------------===//
95 // Abstract Syntax Tree (aka Parse Tree)
96 //===----------------------------------------------------------------------===//
97 namespace {
98 /// ExprAST - Base class for all expression nodes.
99 class ExprAST {
100 public:
~ExprAST()101   virtual ~ExprAST() {}
102   virtual Value *codegen() = 0;
103 };
104 
105 /// NumberExprAST - Expression class for numeric literals like "1.0".
106 class NumberExprAST : public ExprAST {
107   double Val;
108 
109 public:
NumberExprAST(double Val)110   NumberExprAST(double Val) : Val(Val) {}
111   Value *codegen() override;
112 };
113 
114 /// VariableExprAST - Expression class for referencing a variable, like "a".
115 class VariableExprAST : public ExprAST {
116   std::string Name;
117 
118 public:
VariableExprAST(const std::string & Name)119   VariableExprAST(const std::string &Name) : Name(Name) {}
120   Value *codegen() override;
121 };
122 
123 /// BinaryExprAST - Expression class for a binary operator.
124 class BinaryExprAST : public ExprAST {
125   char Op;
126   std::unique_ptr<ExprAST> LHS, RHS;
127 
128 public:
BinaryExprAST(char Op,std::unique_ptr<ExprAST> LHS,std::unique_ptr<ExprAST> RHS)129   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
130                 std::unique_ptr<ExprAST> RHS)
131       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
132   Value *codegen() override;
133 };
134 
135 /// CallExprAST - Expression class for function calls.
136 class CallExprAST : public ExprAST {
137   std::string Callee;
138   std::vector<std::unique_ptr<ExprAST>> Args;
139 
140 public:
CallExprAST(const std::string & Callee,std::vector<std::unique_ptr<ExprAST>> Args)141   CallExprAST(const std::string &Callee,
142               std::vector<std::unique_ptr<ExprAST>> Args)
143       : Callee(Callee), Args(std::move(Args)) {}
144   Value *codegen() override;
145 };
146 
147 /// PrototypeAST - This class represents the "prototype" for a function,
148 /// which captures its name, and its argument names (thus implicitly the number
149 /// of arguments the function takes).
150 class PrototypeAST {
151   std::string Name;
152   std::vector<std::string> Args;
153 
154 public:
PrototypeAST(const std::string & Name,std::vector<std::string> Args)155   PrototypeAST(const std::string &Name, std::vector<std::string> Args)
156       : Name(Name), Args(std::move(Args)) {}
157   Function *codegen();
getName() const158   const std::string &getName() const { return Name; }
159 };
160 
161 /// FunctionAST - This class represents a function definition itself.
162 class FunctionAST {
163   std::unique_ptr<PrototypeAST> Proto;
164   std::unique_ptr<ExprAST> Body;
165 
166 public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,std::unique_ptr<ExprAST> Body)167   FunctionAST(std::unique_ptr<PrototypeAST> Proto,
168               std::unique_ptr<ExprAST> Body)
169       : Proto(std::move(Proto)), Body(std::move(Body)) {}
170   Function *codegen();
171 };
172 } // end anonymous namespace
173 
174 //===----------------------------------------------------------------------===//
175 // Parser
176 //===----------------------------------------------------------------------===//
177 
178 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
179 /// token the parser is looking at.  getNextToken reads another token from the
180 /// lexer and updates CurTok with its results.
181 static int CurTok;
getNextToken()182 static int getNextToken() { return CurTok = gettok(); }
183 
184 /// BinopPrecedence - This holds the precedence for each binary operator that is
185 /// defined.
186 static std::map<char, int> BinopPrecedence;
187 
188 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
GetTokPrecedence()189 static int GetTokPrecedence() {
190   if (!isascii(CurTok))
191     return -1;
192 
193   // Make sure it's a declared binop.
194   int TokPrec = BinopPrecedence[CurTok];
195   if (TokPrec <= 0)
196     return -1;
197   return TokPrec;
198 }
199 
200 /// LogError* - These are little helper functions for error handling.
LogError(const char * Str)201 std::unique_ptr<ExprAST> LogError(const char *Str) {
202   fprintf(stderr, "Error: %s\n", Str);
203   return nullptr;
204 }
205 
LogErrorP(const char * Str)206 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
207   LogError(Str);
208   return nullptr;
209 }
210 
211 static std::unique_ptr<ExprAST> ParseExpression();
212 
213 /// numberexpr ::= number
ParseNumberExpr()214 static std::unique_ptr<ExprAST> ParseNumberExpr() {
215   auto Result = llvm::make_unique<NumberExprAST>(NumVal);
216   getNextToken(); // consume the number
217   return std::move(Result);
218 }
219 
220 /// parenexpr ::= '(' expression ')'
ParseParenExpr()221 static std::unique_ptr<ExprAST> ParseParenExpr() {
222   getNextToken(); // eat (.
223   auto V = ParseExpression();
224   if (!V)
225     return nullptr;
226 
227   if (CurTok != ')')
228     return LogError("expected ')'");
229   getNextToken(); // eat ).
230   return V;
231 }
232 
233 /// identifierexpr
234 ///   ::= identifier
235 ///   ::= identifier '(' expression* ')'
ParseIdentifierExpr()236 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
237   std::string IdName = IdentifierStr;
238 
239   getNextToken(); // eat identifier.
240 
241   if (CurTok != '(') // Simple variable ref.
242     return llvm::make_unique<VariableExprAST>(IdName);
243 
244   // Call.
245   getNextToken(); // eat (
246   std::vector<std::unique_ptr<ExprAST>> Args;
247   if (CurTok != ')') {
248     while (true) {
249       if (auto Arg = ParseExpression())
250         Args.push_back(std::move(Arg));
251       else
252         return nullptr;
253 
254       if (CurTok == ')')
255         break;
256 
257       if (CurTok != ',')
258         return LogError("Expected ')' or ',' in argument list");
259       getNextToken();
260     }
261   }
262 
263   // Eat the ')'.
264   getNextToken();
265 
266   return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
267 }
268 
269 /// primary
270 ///   ::= identifierexpr
271 ///   ::= numberexpr
272 ///   ::= parenexpr
ParsePrimary()273 static std::unique_ptr<ExprAST> ParsePrimary() {
274   switch (CurTok) {
275   default:
276     return LogError("unknown token when expecting an expression");
277   case tok_identifier:
278     return ParseIdentifierExpr();
279   case tok_number:
280     return ParseNumberExpr();
281   case '(':
282     return ParseParenExpr();
283   }
284 }
285 
286 /// binoprhs
287 ///   ::= ('+' primary)*
ParseBinOpRHS(int ExprPrec,std::unique_ptr<ExprAST> LHS)288 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
289                                               std::unique_ptr<ExprAST> LHS) {
290   // If this is a binop, find its precedence.
291   while (true) {
292     int TokPrec = GetTokPrecedence();
293 
294     // If this is a binop that binds at least as tightly as the current binop,
295     // consume it, otherwise we are done.
296     if (TokPrec < ExprPrec)
297       return LHS;
298 
299     // Okay, we know this is a binop.
300     int BinOp = CurTok;
301     getNextToken(); // eat binop
302 
303     // Parse the primary expression after the binary operator.
304     auto RHS = ParsePrimary();
305     if (!RHS)
306       return nullptr;
307 
308     // If BinOp binds less tightly with RHS than the operator after RHS, let
309     // the pending operator take RHS as its LHS.
310     int NextPrec = GetTokPrecedence();
311     if (TokPrec < NextPrec) {
312       RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
313       if (!RHS)
314         return nullptr;
315     }
316 
317     // Merge LHS/RHS.
318     LHS =
319         llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
320   }
321 }
322 
323 /// expression
324 ///   ::= primary binoprhs
325 ///
ParseExpression()326 static std::unique_ptr<ExprAST> ParseExpression() {
327   auto LHS = ParsePrimary();
328   if (!LHS)
329     return nullptr;
330 
331   return ParseBinOpRHS(0, std::move(LHS));
332 }
333 
334 /// prototype
335 ///   ::= id '(' id* ')'
ParsePrototype()336 static std::unique_ptr<PrototypeAST> ParsePrototype() {
337   if (CurTok != tok_identifier)
338     return LogErrorP("Expected function name in prototype");
339 
340   std::string FnName = IdentifierStr;
341   getNextToken();
342 
343   if (CurTok != '(')
344     return LogErrorP("Expected '(' in prototype");
345 
346   std::vector<std::string> ArgNames;
347   while (getNextToken() == tok_identifier)
348     ArgNames.push_back(IdentifierStr);
349   if (CurTok != ')')
350     return LogErrorP("Expected ')' in prototype");
351 
352   // success.
353   getNextToken(); // eat ')'.
354 
355   return llvm::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
356 }
357 
358 /// definition ::= 'def' prototype expression
ParseDefinition()359 static std::unique_ptr<FunctionAST> ParseDefinition() {
360   getNextToken(); // eat def.
361   auto Proto = ParsePrototype();
362   if (!Proto)
363     return nullptr;
364 
365   if (auto E = ParseExpression())
366     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
367   return nullptr;
368 }
369 
370 /// toplevelexpr ::= expression
ParseTopLevelExpr()371 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
372   if (auto E = ParseExpression()) {
373     // Make an anonymous proto.
374     auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
375                                                  std::vector<std::string>());
376     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
377   }
378   return nullptr;
379 }
380 
381 /// external ::= 'extern' prototype
ParseExtern()382 static std::unique_ptr<PrototypeAST> ParseExtern() {
383   getNextToken(); // eat extern.
384   return ParsePrototype();
385 }
386 
387 //===----------------------------------------------------------------------===//
388 // Code Generation
389 //===----------------------------------------------------------------------===//
390 
391 static LLVMContext TheContext;
392 static IRBuilder<> Builder(TheContext);
393 static std::unique_ptr<Module> TheModule;
394 static std::map<std::string, Value *> NamedValues;
395 
LogErrorV(const char * Str)396 Value *LogErrorV(const char *Str) {
397   LogError(Str);
398   return nullptr;
399 }
400 
codegen()401 Value *NumberExprAST::codegen() {
402   return ConstantFP::get(TheContext, APFloat(Val));
403 }
404 
codegen()405 Value *VariableExprAST::codegen() {
406   // Look this variable up in the function.
407   Value *V = NamedValues[Name];
408   if (!V)
409     return LogErrorV("Unknown variable name");
410   return V;
411 }
412 
codegen()413 Value *BinaryExprAST::codegen() {
414   Value *L = LHS->codegen();
415   Value *R = RHS->codegen();
416   if (!L || !R)
417     return nullptr;
418 
419   switch (Op) {
420   case '+':
421     return Builder.CreateFAdd(L, R, "addtmp");
422   case '-':
423     return Builder.CreateFSub(L, R, "subtmp");
424   case '*':
425     return Builder.CreateFMul(L, R, "multmp");
426   case '<':
427     L = Builder.CreateFCmpULT(L, R, "cmptmp");
428     // Convert bool 0/1 to double 0.0 or 1.0
429     return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
430   default:
431     return LogErrorV("invalid binary operator");
432   }
433 }
434 
codegen()435 Value *CallExprAST::codegen() {
436   // Look up the name in the global module table.
437   Function *CalleeF = TheModule->getFunction(Callee);
438   if (!CalleeF)
439     return LogErrorV("Unknown function referenced");
440 
441   // If argument mismatch error.
442   if (CalleeF->arg_size() != Args.size())
443     return LogErrorV("Incorrect # arguments passed");
444 
445   std::vector<Value *> ArgsV;
446   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
447     ArgsV.push_back(Args[i]->codegen());
448     if (!ArgsV.back())
449       return nullptr;
450   }
451 
452   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
453 }
454 
codegen()455 Function *PrototypeAST::codegen() {
456   // Make the function type:  double(double,double) etc.
457   std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext));
458   FunctionType *FT =
459       FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
460 
461   Function *F =
462       Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
463 
464   // Set names for all arguments.
465   unsigned Idx = 0;
466   for (auto &Arg : F->args())
467     Arg.setName(Args[Idx++]);
468 
469   return F;
470 }
471 
codegen()472 Function *FunctionAST::codegen() {
473   // First, check for an existing function from a previous 'extern' declaration.
474   Function *TheFunction = TheModule->getFunction(Proto->getName());
475 
476   if (!TheFunction)
477     TheFunction = Proto->codegen();
478 
479   if (!TheFunction)
480     return nullptr;
481 
482   // Create a new basic block to start insertion into.
483   BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction);
484   Builder.SetInsertPoint(BB);
485 
486   // Record the function arguments in the NamedValues map.
487   NamedValues.clear();
488   for (auto &Arg : TheFunction->args())
489     NamedValues[Arg.getName()] = &Arg;
490 
491   if (Value *RetVal = Body->codegen()) {
492     // Finish off the function.
493     Builder.CreateRet(RetVal);
494 
495     // Validate the generated code, checking for consistency.
496     verifyFunction(*TheFunction);
497 
498     return TheFunction;
499   }
500 
501   // Error reading body, remove function.
502   TheFunction->eraseFromParent();
503   return nullptr;
504 }
505 
506 //===----------------------------------------------------------------------===//
507 // Top-Level parsing and JIT Driver
508 //===----------------------------------------------------------------------===//
509 
HandleDefinition()510 static void HandleDefinition() {
511   if (auto FnAST = ParseDefinition()) {
512     if (auto *FnIR = FnAST->codegen()) {
513       fprintf(stderr, "Read function definition:");
514       FnIR->dump();
515     }
516   } else {
517     // Skip token for error recovery.
518     getNextToken();
519   }
520 }
521 
HandleExtern()522 static void HandleExtern() {
523   if (auto ProtoAST = ParseExtern()) {
524     if (auto *FnIR = ProtoAST->codegen()) {
525       fprintf(stderr, "Read extern: ");
526       FnIR->dump();
527     }
528   } else {
529     // Skip token for error recovery.
530     getNextToken();
531   }
532 }
533 
HandleTopLevelExpression()534 static void HandleTopLevelExpression() {
535   // Evaluate a top-level expression into an anonymous function.
536   if (auto FnAST = ParseTopLevelExpr()) {
537     if (auto *FnIR = FnAST->codegen()) {
538       fprintf(stderr, "Read top-level expression:");
539       FnIR->dump();
540     }
541   } else {
542     // Skip token for error recovery.
543     getNextToken();
544   }
545 }
546 
547 /// top ::= definition | external | expression | ';'
MainLoop()548 static void MainLoop() {
549   while (true) {
550     fprintf(stderr, "ready> ");
551     switch (CurTok) {
552     case tok_eof:
553       return;
554     case ';': // ignore top-level semicolons.
555       getNextToken();
556       break;
557     case tok_def:
558       HandleDefinition();
559       break;
560     case tok_extern:
561       HandleExtern();
562       break;
563     default:
564       HandleTopLevelExpression();
565       break;
566     }
567   }
568 }
569 
570 //===----------------------------------------------------------------------===//
571 // Main driver code.
572 //===----------------------------------------------------------------------===//
573 
main()574 int main() {
575   // Install standard binary operators.
576   // 1 is lowest precedence.
577   BinopPrecedence['<'] = 10;
578   BinopPrecedence['+'] = 20;
579   BinopPrecedence['-'] = 20;
580   BinopPrecedence['*'] = 40; // highest.
581 
582   // Prime the first token.
583   fprintf(stderr, "ready> ");
584   getNextToken();
585 
586   // Make the module, which holds all the code.
587   TheModule = llvm::make_unique<Module>("my cool jit", TheContext);
588 
589   // Run the main "interpreter loop" now.
590   MainLoop();
591 
592   // Print out all of the generated code.
593   TheModule->dump();
594 
595   return 0;
596 }
597