• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include "llvm/ADT/APFloat.h"
2 #include "llvm/ADT/STLExtras.h"
3 #include "llvm/IR/BasicBlock.h"
4 #include "llvm/IR/Constants.h"
5 #include "llvm/IR/DerivedTypes.h"
6 #include "llvm/IR/Function.h"
7 #include "llvm/IR/Instructions.h"
8 #include "llvm/IR/IRBuilder.h"
9 #include "llvm/IR/LLVMContext.h"
10 #include "llvm/IR/LegacyPassManager.h"
11 #include "llvm/IR/Module.h"
12 #include "llvm/IR/Type.h"
13 #include "llvm/IR/Verifier.h"
14 #include "llvm/Support/Error.h"
15 #include "llvm/Support/TargetSelect.h"
16 #include "llvm/Target/TargetMachine.h"
17 #include "llvm/Transforms/Scalar.h"
18 #include "llvm/Transforms/Scalar/GVN.h"
19 #include "KaleidoscopeJIT.h"
20 #include <cassert>
21 #include <cctype>
22 #include <cstdint>
23 #include <cstdio>
24 #include <cstdlib>
25 #include <map>
26 #include <memory>
27 #include <string>
28 #include <utility>
29 #include <vector>
30 
31 using namespace llvm;
32 using namespace llvm::orc;
33 
34 //===----------------------------------------------------------------------===//
35 // Lexer
36 //===----------------------------------------------------------------------===//
37 
/// Token - Kinds of token produced by the lexer.  A character the lexer does
/// not recognize is returned as its own character code in [0-255]; every
/// recognized kind below has a distinct negative value so the two ranges can
/// never collide.
enum Token {
  // End of input.
  tok_eof = -1,

  // Command keywords: 'def', 'extern'.
  tok_def = -2,
  tok_extern = -3,

  // Primary expressions: identifiers and numeric literals.
  tok_identifier = -4,
  tok_number = -5,

  // Control-flow keywords: 'if', 'then', 'else', 'for', 'in'.
  tok_if = -6,
  tok_then = -7,
  tok_else = -8,
  tok_for = -9,
  tok_in = -10,

  // User-defined operator keywords: 'binary', 'unary'.
  tok_binary = -11,
  tok_unary = -12,

  // Mutable variable definition keyword: 'var'.
  tok_var = -13
};
65 
66 static std::string IdentifierStr; // Filled in if tok_identifier
67 static double NumVal;             // Filled in if tok_number
68 
69 /// gettok - Return the next token from standard input.
gettok()70 static int gettok() {
71   static int LastChar = ' ';
72 
73   // Skip any whitespace.
74   while (isspace(LastChar))
75     LastChar = getchar();
76 
77   if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
78     IdentifierStr = LastChar;
79     while (isalnum((LastChar = getchar())))
80       IdentifierStr += LastChar;
81 
82     if (IdentifierStr == "def")
83       return tok_def;
84     if (IdentifierStr == "extern")
85       return tok_extern;
86     if (IdentifierStr == "if")
87       return tok_if;
88     if (IdentifierStr == "then")
89       return tok_then;
90     if (IdentifierStr == "else")
91       return tok_else;
92     if (IdentifierStr == "for")
93       return tok_for;
94     if (IdentifierStr == "in")
95       return tok_in;
96     if (IdentifierStr == "binary")
97       return tok_binary;
98     if (IdentifierStr == "unary")
99       return tok_unary;
100     if (IdentifierStr == "var")
101       return tok_var;
102     return tok_identifier;
103   }
104 
105   if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
106     std::string NumStr;
107     do {
108       NumStr += LastChar;
109       LastChar = getchar();
110     } while (isdigit(LastChar) || LastChar == '.');
111 
112     NumVal = strtod(NumStr.c_str(), nullptr);
113     return tok_number;
114   }
115 
116   if (LastChar == '#') {
117     // Comment until end of line.
118     do
119       LastChar = getchar();
120     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
121 
122     if (LastChar != EOF)
123       return gettok();
124   }
125 
126   // Check for end of file.  Don't eat the EOF.
127   if (LastChar == EOF)
128     return tok_eof;
129 
130   // Otherwise, just return the character as its ascii value.
131   int ThisChar = LastChar;
132   LastChar = getchar();
133   return ThisChar;
134 }
135 
136 //===----------------------------------------------------------------------===//
137 // Abstract Syntax Tree (aka Parse Tree)
138 //===----------------------------------------------------------------------===//
139 
140 /// ExprAST - Base class for all expression nodes.
141 class ExprAST {
142 public:
~ExprAST()143   virtual ~ExprAST() {}
144   virtual Value *codegen() = 0;
145 };
146 
147 /// NumberExprAST - Expression class for numeric literals like "1.0".
148 class NumberExprAST : public ExprAST {
149   double Val;
150 
151 public:
NumberExprAST(double Val)152   NumberExprAST(double Val) : Val(Val) {}
153   Value *codegen() override;
154 };
155 
156 /// VariableExprAST - Expression class for referencing a variable, like "a".
157 class VariableExprAST : public ExprAST {
158   std::string Name;
159 
160 public:
VariableExprAST(const std::string & Name)161   VariableExprAST(const std::string &Name) : Name(Name) {}
getName() const162   const std::string &getName() const { return Name; }
163   Value *codegen() override;
164 };
165 
166 /// UnaryExprAST - Expression class for a unary operator.
167 class UnaryExprAST : public ExprAST {
168   char Opcode;
169   std::unique_ptr<ExprAST> Operand;
170 
171 public:
UnaryExprAST(char Opcode,std::unique_ptr<ExprAST> Operand)172   UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
173       : Opcode(Opcode), Operand(std::move(Operand)) {}
174   Value *codegen() override;
175 };
176 
177 /// BinaryExprAST - Expression class for a binary operator.
178 class BinaryExprAST : public ExprAST {
179   char Op;
180   std::unique_ptr<ExprAST> LHS, RHS;
181 
182 public:
BinaryExprAST(char Op,std::unique_ptr<ExprAST> LHS,std::unique_ptr<ExprAST> RHS)183   BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
184                 std::unique_ptr<ExprAST> RHS)
185       : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
186   Value *codegen() override;
187 };
188 
189 /// CallExprAST - Expression class for function calls.
190 class CallExprAST : public ExprAST {
191   std::string Callee;
192   std::vector<std::unique_ptr<ExprAST>> Args;
193 
194 public:
CallExprAST(const std::string & Callee,std::vector<std::unique_ptr<ExprAST>> Args)195   CallExprAST(const std::string &Callee,
196               std::vector<std::unique_ptr<ExprAST>> Args)
197       : Callee(Callee), Args(std::move(Args)) {}
198   Value *codegen() override;
199 };
200 
201 /// IfExprAST - Expression class for if/then/else.
202 class IfExprAST : public ExprAST {
203   std::unique_ptr<ExprAST> Cond, Then, Else;
204 
205 public:
IfExprAST(std::unique_ptr<ExprAST> Cond,std::unique_ptr<ExprAST> Then,std::unique_ptr<ExprAST> Else)206   IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
207             std::unique_ptr<ExprAST> Else)
208       : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
209   Value *codegen() override;
210 };
211 
212 /// ForExprAST - Expression class for for/in.
213 class ForExprAST : public ExprAST {
214   std::string VarName;
215   std::unique_ptr<ExprAST> Start, End, Step, Body;
216 
217 public:
ForExprAST(const std::string & VarName,std::unique_ptr<ExprAST> Start,std::unique_ptr<ExprAST> End,std::unique_ptr<ExprAST> Step,std::unique_ptr<ExprAST> Body)218   ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
219              std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
220              std::unique_ptr<ExprAST> Body)
221       : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
222         Step(std::move(Step)), Body(std::move(Body)) {}
223   Value *codegen() override;
224 };
225 
226 /// VarExprAST - Expression class for var/in
227 class VarExprAST : public ExprAST {
228   std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames;
229   std::unique_ptr<ExprAST> Body;
230 
231 public:
VarExprAST(std::vector<std::pair<std::string,std::unique_ptr<ExprAST>>> VarNames,std::unique_ptr<ExprAST> Body)232   VarExprAST(
233       std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames,
234       std::unique_ptr<ExprAST> Body)
235       : VarNames(std::move(VarNames)), Body(std::move(Body)) {}
236   Value *codegen() override;
237 };
238 
239 /// PrototypeAST - This class represents the "prototype" for a function,
240 /// which captures its name, and its argument names (thus implicitly the number
241 /// of arguments the function takes), as well as if it is an operator.
242 class PrototypeAST {
243   std::string Name;
244   std::vector<std::string> Args;
245   bool IsOperator;
246   unsigned Precedence; // Precedence if a binary op.
247 
248 public:
PrototypeAST(const std::string & Name,std::vector<std::string> Args,bool IsOperator=false,unsigned Prec=0)249   PrototypeAST(const std::string &Name, std::vector<std::string> Args,
250                bool IsOperator = false, unsigned Prec = 0)
251       : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
252         Precedence(Prec) {}
253   Function *codegen();
getName() const254   const std::string &getName() const { return Name; }
255 
isUnaryOp() const256   bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
isBinaryOp() const257   bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
258 
getOperatorName() const259   char getOperatorName() const {
260     assert(isUnaryOp() || isBinaryOp());
261     return Name[Name.size() - 1];
262   }
263 
getBinaryPrecedence() const264   unsigned getBinaryPrecedence() const { return Precedence; }
265 };
266 
267 //===----------------------------------------------------------------------===//
268 // Parser
269 //===----------------------------------------------------------------------===//
270 
271 /// CurTok/getNextToken - Provide a simple token buffer.  CurTok is the current
272 /// token the parser is looking at.  getNextToken reads another token from the
273 /// lexer and updates CurTok with its results.
274 static int CurTok;
getNextToken()275 static int getNextToken() { return CurTok = gettok(); }
276 
277 /// BinopPrecedence - This holds the precedence for each binary operator that is
278 /// defined.
279 static std::map<char, int> BinopPrecedence;
280 
281 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
GetTokPrecedence()282 static int GetTokPrecedence() {
283   if (!isascii(CurTok))
284     return -1;
285 
286   // Make sure it's a declared binop.
287   int TokPrec = BinopPrecedence[CurTok];
288   if (TokPrec <= 0)
289     return -1;
290   return TokPrec;
291 }
292 
293 /// LogError* - These are little helper functions for error handling.
LogError(const char * Str)294 std::unique_ptr<ExprAST> LogError(const char *Str) {
295   fprintf(stderr, "Error: %s\n", Str);
296   return nullptr;
297 }
298 
LogErrorP(const char * Str)299 std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
300   LogError(Str);
301   return nullptr;
302 }
303 
304 static std::unique_ptr<ExprAST> ParseExpression();
305 
306 /// numberexpr ::= number
ParseNumberExpr()307 static std::unique_ptr<ExprAST> ParseNumberExpr() {
308   auto Result = llvm::make_unique<NumberExprAST>(NumVal);
309   getNextToken(); // consume the number
310   return std::move(Result);
311 }
312 
313 /// parenexpr ::= '(' expression ')'
ParseParenExpr()314 static std::unique_ptr<ExprAST> ParseParenExpr() {
315   getNextToken(); // eat (.
316   auto V = ParseExpression();
317   if (!V)
318     return nullptr;
319 
320   if (CurTok != ')')
321     return LogError("expected ')'");
322   getNextToken(); // eat ).
323   return V;
324 }
325 
326 /// identifierexpr
327 ///   ::= identifier
328 ///   ::= identifier '(' expression* ')'
ParseIdentifierExpr()329 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
330   std::string IdName = IdentifierStr;
331 
332   getNextToken(); // eat identifier.
333 
334   if (CurTok != '(') // Simple variable ref.
335     return llvm::make_unique<VariableExprAST>(IdName);
336 
337   // Call.
338   getNextToken(); // eat (
339   std::vector<std::unique_ptr<ExprAST>> Args;
340   if (CurTok != ')') {
341     while (true) {
342       if (auto Arg = ParseExpression())
343         Args.push_back(std::move(Arg));
344       else
345         return nullptr;
346 
347       if (CurTok == ')')
348         break;
349 
350       if (CurTok != ',')
351         return LogError("Expected ')' or ',' in argument list");
352       getNextToken();
353     }
354   }
355 
356   // Eat the ')'.
357   getNextToken();
358 
359   return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
360 }
361 
362 /// ifexpr ::= 'if' expression 'then' expression 'else' expression
ParseIfExpr()363 static std::unique_ptr<ExprAST> ParseIfExpr() {
364   getNextToken(); // eat the if.
365 
366   // condition.
367   auto Cond = ParseExpression();
368   if (!Cond)
369     return nullptr;
370 
371   if (CurTok != tok_then)
372     return LogError("expected then");
373   getNextToken(); // eat the then
374 
375   auto Then = ParseExpression();
376   if (!Then)
377     return nullptr;
378 
379   if (CurTok != tok_else)
380     return LogError("expected else");
381 
382   getNextToken();
383 
384   auto Else = ParseExpression();
385   if (!Else)
386     return nullptr;
387 
388   return llvm::make_unique<IfExprAST>(std::move(Cond), std::move(Then),
389                                       std::move(Else));
390 }
391 
392 /// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
ParseForExpr()393 static std::unique_ptr<ExprAST> ParseForExpr() {
394   getNextToken(); // eat the for.
395 
396   if (CurTok != tok_identifier)
397     return LogError("expected identifier after for");
398 
399   std::string IdName = IdentifierStr;
400   getNextToken(); // eat identifier.
401 
402   if (CurTok != '=')
403     return LogError("expected '=' after for");
404   getNextToken(); // eat '='.
405 
406   auto Start = ParseExpression();
407   if (!Start)
408     return nullptr;
409   if (CurTok != ',')
410     return LogError("expected ',' after for start value");
411   getNextToken();
412 
413   auto End = ParseExpression();
414   if (!End)
415     return nullptr;
416 
417   // The step value is optional.
418   std::unique_ptr<ExprAST> Step;
419   if (CurTok == ',') {
420     getNextToken();
421     Step = ParseExpression();
422     if (!Step)
423       return nullptr;
424   }
425 
426   if (CurTok != tok_in)
427     return LogError("expected 'in' after for");
428   getNextToken(); // eat 'in'.
429 
430   auto Body = ParseExpression();
431   if (!Body)
432     return nullptr;
433 
434   return llvm::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
435                                        std::move(Step), std::move(Body));
436 }
437 
438 /// varexpr ::= 'var' identifier ('=' expression)?
439 //                    (',' identifier ('=' expression)?)* 'in' expression
ParseVarExpr()440 static std::unique_ptr<ExprAST> ParseVarExpr() {
441   getNextToken(); // eat the var.
442 
443   std::vector<std::pair<std::string, std::unique_ptr<ExprAST>>> VarNames;
444 
445   // At least one variable name is required.
446   if (CurTok != tok_identifier)
447     return LogError("expected identifier after var");
448 
449   while (true) {
450     std::string Name = IdentifierStr;
451     getNextToken(); // eat identifier.
452 
453     // Read the optional initializer.
454     std::unique_ptr<ExprAST> Init = nullptr;
455     if (CurTok == '=') {
456       getNextToken(); // eat the '='.
457 
458       Init = ParseExpression();
459       if (!Init)
460         return nullptr;
461     }
462 
463     VarNames.push_back(std::make_pair(Name, std::move(Init)));
464 
465     // End of var list, exit loop.
466     if (CurTok != ',')
467       break;
468     getNextToken(); // eat the ','.
469 
470     if (CurTok != tok_identifier)
471       return LogError("expected identifier list after var");
472   }
473 
474   // At this point, we have to have 'in'.
475   if (CurTok != tok_in)
476     return LogError("expected 'in' keyword after 'var'");
477   getNextToken(); // eat 'in'.
478 
479   auto Body = ParseExpression();
480   if (!Body)
481     return nullptr;
482 
483   return llvm::make_unique<VarExprAST>(std::move(VarNames), std::move(Body));
484 }
485 
486 /// primary
487 ///   ::= identifierexpr
488 ///   ::= numberexpr
489 ///   ::= parenexpr
490 ///   ::= ifexpr
491 ///   ::= forexpr
492 ///   ::= varexpr
ParsePrimary()493 static std::unique_ptr<ExprAST> ParsePrimary() {
494   switch (CurTok) {
495   default:
496     return LogError("unknown token when expecting an expression");
497   case tok_identifier:
498     return ParseIdentifierExpr();
499   case tok_number:
500     return ParseNumberExpr();
501   case '(':
502     return ParseParenExpr();
503   case tok_if:
504     return ParseIfExpr();
505   case tok_for:
506     return ParseForExpr();
507   case tok_var:
508     return ParseVarExpr();
509   }
510 }
511 
512 /// unary
513 ///   ::= primary
514 ///   ::= '!' unary
ParseUnary()515 static std::unique_ptr<ExprAST> ParseUnary() {
516   // If the current token is not an operator, it must be a primary expr.
517   if (!isascii(CurTok) || CurTok == '(' || CurTok == ',')
518     return ParsePrimary();
519 
520   // If this is a unary operator, read it.
521   int Opc = CurTok;
522   getNextToken();
523   if (auto Operand = ParseUnary())
524     return llvm::make_unique<UnaryExprAST>(Opc, std::move(Operand));
525   return nullptr;
526 }
527 
528 /// binoprhs
529 ///   ::= ('+' unary)*
ParseBinOpRHS(int ExprPrec,std::unique_ptr<ExprAST> LHS)530 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
531                                               std::unique_ptr<ExprAST> LHS) {
532   // If this is a binop, find its precedence.
533   while (true) {
534     int TokPrec = GetTokPrecedence();
535 
536     // If this is a binop that binds at least as tightly as the current binop,
537     // consume it, otherwise we are done.
538     if (TokPrec < ExprPrec)
539       return LHS;
540 
541     // Okay, we know this is a binop.
542     int BinOp = CurTok;
543     getNextToken(); // eat binop
544 
545     // Parse the unary expression after the binary operator.
546     auto RHS = ParseUnary();
547     if (!RHS)
548       return nullptr;
549 
550     // If BinOp binds less tightly with RHS than the operator after RHS, let
551     // the pending operator take RHS as its LHS.
552     int NextPrec = GetTokPrecedence();
553     if (TokPrec < NextPrec) {
554       RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
555       if (!RHS)
556         return nullptr;
557     }
558 
559     // Merge LHS/RHS.
560     LHS =
561         llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
562   }
563 }
564 
565 /// expression
566 ///   ::= unary binoprhs
567 ///
ParseExpression()568 static std::unique_ptr<ExprAST> ParseExpression() {
569   auto LHS = ParseUnary();
570   if (!LHS)
571     return nullptr;
572 
573   return ParseBinOpRHS(0, std::move(LHS));
574 }
575 
576 /// prototype
577 ///   ::= id '(' id* ')'
578 ///   ::= binary LETTER number? (id, id)
579 ///   ::= unary LETTER (id)
ParsePrototype()580 static std::unique_ptr<PrototypeAST> ParsePrototype() {
581   std::string FnName;
582 
583   unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary.
584   unsigned BinaryPrecedence = 30;
585 
586   switch (CurTok) {
587   default:
588     return LogErrorP("Expected function name in prototype");
589   case tok_identifier:
590     FnName = IdentifierStr;
591     Kind = 0;
592     getNextToken();
593     break;
594   case tok_unary:
595     getNextToken();
596     if (!isascii(CurTok))
597       return LogErrorP("Expected unary operator");
598     FnName = "unary";
599     FnName += (char)CurTok;
600     Kind = 1;
601     getNextToken();
602     break;
603   case tok_binary:
604     getNextToken();
605     if (!isascii(CurTok))
606       return LogErrorP("Expected binary operator");
607     FnName = "binary";
608     FnName += (char)CurTok;
609     Kind = 2;
610     getNextToken();
611 
612     // Read the precedence if present.
613     if (CurTok == tok_number) {
614       if (NumVal < 1 || NumVal > 100)
615         return LogErrorP("Invalid precedecnce: must be 1..100");
616       BinaryPrecedence = (unsigned)NumVal;
617       getNextToken();
618     }
619     break;
620   }
621 
622   if (CurTok != '(')
623     return LogErrorP("Expected '(' in prototype");
624 
625   std::vector<std::string> ArgNames;
626   while (getNextToken() == tok_identifier)
627     ArgNames.push_back(IdentifierStr);
628   if (CurTok != ')')
629     return LogErrorP("Expected ')' in prototype");
630 
631   // success.
632   getNextToken(); // eat ')'.
633 
634   // Verify right number of names for operator.
635   if (Kind && ArgNames.size() != Kind)
636     return LogErrorP("Invalid number of operands for operator");
637 
638   return llvm::make_unique<PrototypeAST>(FnName, ArgNames, Kind != 0,
639                                          BinaryPrecedence);
640 }
641 
642 /// definition ::= 'def' prototype expression
ParseDefinition()643 static std::unique_ptr<FunctionAST> ParseDefinition() {
644   getNextToken(); // eat def.
645   auto Proto = ParsePrototype();
646   if (!Proto)
647     return nullptr;
648 
649   if (auto E = ParseExpression())
650     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
651   return nullptr;
652 }
653 
654 /// toplevelexpr ::= expression
ParseTopLevelExpr()655 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
656   if (auto E = ParseExpression()) {
657     // Make an anonymous proto.
658     auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
659                                                  std::vector<std::string>());
660     return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
661   }
662   return nullptr;
663 }
664 
665 /// external ::= 'extern' prototype
ParseExtern()666 static std::unique_ptr<PrototypeAST> ParseExtern() {
667   getNextToken(); // eat extern.
668   return ParsePrototype();
669 }
670 
671 //===----------------------------------------------------------------------===//
672 // Code Generation
673 //===----------------------------------------------------------------------===//
674 
675 static LLVMContext TheContext;
676 static IRBuilder<> Builder(TheContext);
677 static std::unique_ptr<Module> TheModule;
678 static std::map<std::string, AllocaInst *> NamedValues;
679 static std::unique_ptr<KaleidoscopeJIT> TheJIT;
680 static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
681 static ExitOnError ExitOnErr;
682 
LogErrorV(const char * Str)683 Value *LogErrorV(const char *Str) {
684   LogError(Str);
685   return nullptr;
686 }
687 
getFunction(std::string Name)688 Function *getFunction(std::string Name) {
689   // First, see if the function has already been added to the current module.
690   if (auto *F = TheModule->getFunction(Name))
691     return F;
692 
693   // If not, check whether we can codegen the declaration from some existing
694   // prototype.
695   auto FI = FunctionProtos.find(Name);
696   if (FI != FunctionProtos.end())
697     return FI->second->codegen();
698 
699   // If no existing prototype exists, return null.
700   return nullptr;
701 }
702 
703 /// CreateEntryBlockAlloca - Create an alloca instruction in the entry block of
704 /// the function.  This is used for mutable variables etc.
CreateEntryBlockAlloca(Function * TheFunction,const std::string & VarName)705 static AllocaInst *CreateEntryBlockAlloca(Function *TheFunction,
706                                           const std::string &VarName) {
707   IRBuilder<> TmpB(&TheFunction->getEntryBlock(),
708                    TheFunction->getEntryBlock().begin());
709   return TmpB.CreateAlloca(Type::getDoubleTy(TheContext), nullptr, VarName);
710 }
711 
codegen()712 Value *NumberExprAST::codegen() {
713   return ConstantFP::get(TheContext, APFloat(Val));
714 }
715 
codegen()716 Value *VariableExprAST::codegen() {
717   // Look this variable up in the function.
718   Value *V = NamedValues[Name];
719   if (!V)
720     return LogErrorV("Unknown variable name");
721 
722   // Load the value.
723   return Builder.CreateLoad(V, Name.c_str());
724 }
725 
codegen()726 Value *UnaryExprAST::codegen() {
727   Value *OperandV = Operand->codegen();
728   if (!OperandV)
729     return nullptr;
730 
731   Function *F = getFunction(std::string("unary") + Opcode);
732   if (!F)
733     return LogErrorV("Unknown unary operator");
734 
735   return Builder.CreateCall(F, OperandV, "unop");
736 }
737 
codegen()738 Value *BinaryExprAST::codegen() {
739   // Special case '=' because we don't want to emit the LHS as an expression.
740   if (Op == '=') {
741     // Assignment requires the LHS to be an identifier.
742     // This assume we're building without RTTI because LLVM builds that way by
743     // default.  If you build LLVM with RTTI this can be changed to a
744     // dynamic_cast for automatic error checking.
745     VariableExprAST *LHSE = static_cast<VariableExprAST *>(LHS.get());
746     if (!LHSE)
747       return LogErrorV("destination of '=' must be a variable");
748     // Codegen the RHS.
749     Value *Val = RHS->codegen();
750     if (!Val)
751       return nullptr;
752 
753     // Look up the name.
754     Value *Variable = NamedValues[LHSE->getName()];
755     if (!Variable)
756       return LogErrorV("Unknown variable name");
757 
758     Builder.CreateStore(Val, Variable);
759     return Val;
760   }
761 
762   Value *L = LHS->codegen();
763   Value *R = RHS->codegen();
764   if (!L || !R)
765     return nullptr;
766 
767   switch (Op) {
768   case '+':
769     return Builder.CreateFAdd(L, R, "addtmp");
770   case '-':
771     return Builder.CreateFSub(L, R, "subtmp");
772   case '*':
773     return Builder.CreateFMul(L, R, "multmp");
774   case '<':
775     L = Builder.CreateFCmpULT(L, R, "cmptmp");
776     // Convert bool 0/1 to double 0.0 or 1.0
777     return Builder.CreateUIToFP(L, Type::getDoubleTy(TheContext), "booltmp");
778   default:
779     break;
780   }
781 
782   // If it wasn't a builtin binary operator, it must be a user defined one. Emit
783   // a call to it.
784   Function *F = getFunction(std::string("binary") + Op);
785   assert(F && "binary operator not found!");
786 
787   Value *Ops[] = {L, R};
788   return Builder.CreateCall(F, Ops, "binop");
789 }
790 
codegen()791 Value *CallExprAST::codegen() {
792   // Look up the name in the global module table.
793   Function *CalleeF = getFunction(Callee);
794   if (!CalleeF)
795     return LogErrorV("Unknown function referenced");
796 
797   // If argument mismatch error.
798   if (CalleeF->arg_size() != Args.size())
799     return LogErrorV("Incorrect # arguments passed");
800 
801   std::vector<Value *> ArgsV;
802   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
803     ArgsV.push_back(Args[i]->codegen());
804     if (!ArgsV.back())
805       return nullptr;
806   }
807 
808   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
809 }
810 
codegen()811 Value *IfExprAST::codegen() {
812   Value *CondV = Cond->codegen();
813   if (!CondV)
814     return nullptr;
815 
816   // Convert condition to a bool by comparing equal to 0.0.
817   CondV = Builder.CreateFCmpONE(
818       CondV, ConstantFP::get(TheContext, APFloat(0.0)), "ifcond");
819 
820   Function *TheFunction = Builder.GetInsertBlock()->getParent();
821 
822   // Create blocks for the then and else cases.  Insert the 'then' block at the
823   // end of the function.
824   BasicBlock *ThenBB = BasicBlock::Create(TheContext, "then", TheFunction);
825   BasicBlock *ElseBB = BasicBlock::Create(TheContext, "else");
826   BasicBlock *MergeBB = BasicBlock::Create(TheContext, "ifcont");
827 
828   Builder.CreateCondBr(CondV, ThenBB, ElseBB);
829 
830   // Emit then value.
831   Builder.SetInsertPoint(ThenBB);
832 
833   Value *ThenV = Then->codegen();
834   if (!ThenV)
835     return nullptr;
836 
837   Builder.CreateBr(MergeBB);
838   // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
839   ThenBB = Builder.GetInsertBlock();
840 
841   // Emit else block.
842   TheFunction->getBasicBlockList().push_back(ElseBB);
843   Builder.SetInsertPoint(ElseBB);
844 
845   Value *ElseV = Else->codegen();
846   if (!ElseV)
847     return nullptr;
848 
849   Builder.CreateBr(MergeBB);
850   // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
851   ElseBB = Builder.GetInsertBlock();
852 
853   // Emit merge block.
854   TheFunction->getBasicBlockList().push_back(MergeBB);
855   Builder.SetInsertPoint(MergeBB);
856   PHINode *PN = Builder.CreatePHI(Type::getDoubleTy(TheContext), 2, "iftmp");
857 
858   PN->addIncoming(ThenV, ThenBB);
859   PN->addIncoming(ElseV, ElseBB);
860   return PN;
861 }
862 
863 // Output for-loop as:
864 //   var = alloca double
865 //   ...
866 //   start = startexpr
867 //   store start -> var
868 //   goto loop
869 // loop:
870 //   ...
871 //   bodyexpr
872 //   ...
873 // loopend:
874 //   step = stepexpr
875 //   endcond = endexpr
876 //
877 //   curvar = load var
878 //   nextvar = curvar + step
879 //   store nextvar -> var
880 //   br endcond, loop, endloop
881 // outloop:
codegen()882 Value *ForExprAST::codegen() {
883   Function *TheFunction = Builder.GetInsertBlock()->getParent();
884 
885   // Create an alloca for the variable in the entry block.
886   AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName);
887 
888   // Emit the start code first, without 'variable' in scope.
889   Value *StartVal = Start->codegen();
890   if (!StartVal)
891     return nullptr;
892 
893   // Store the value into the alloca.
894   Builder.CreateStore(StartVal, Alloca);
895 
896   // Make the new basic block for the loop header, inserting after current
897   // block.
898   BasicBlock *LoopBB = BasicBlock::Create(TheContext, "loop", TheFunction);
899 
900   // Insert an explicit fall through from the current block to the LoopBB.
901   Builder.CreateBr(LoopBB);
902 
903   // Start insertion in LoopBB.
904   Builder.SetInsertPoint(LoopBB);
905 
906   // Within the loop, the variable is defined equal to the PHI node.  If it
907   // shadows an existing variable, we have to restore it, so save it now.
908   AllocaInst *OldVal = NamedValues[VarName];
909   NamedValues[VarName] = Alloca;
910 
911   // Emit the body of the loop.  This, like any other expr, can change the
912   // current BB.  Note that we ignore the value computed by the body, but don't
913   // allow an error.
914   if (!Body->codegen())
915     return nullptr;
916 
917   // Emit the step value.
918   Value *StepVal = nullptr;
919   if (Step) {
920     StepVal = Step->codegen();
921     if (!StepVal)
922       return nullptr;
923   } else {
924     // If not specified, use 1.0.
925     StepVal = ConstantFP::get(TheContext, APFloat(1.0));
926   }
927 
928   // Compute the end condition.
929   Value *EndCond = End->codegen();
930   if (!EndCond)
931     return nullptr;
932 
933   // Reload, increment, and restore the alloca.  This handles the case where
934   // the body of the loop mutates the variable.
935   Value *CurVar = Builder.CreateLoad(Alloca, VarName.c_str());
936   Value *NextVar = Builder.CreateFAdd(CurVar, StepVal, "nextvar");
937   Builder.CreateStore(NextVar, Alloca);
938 
939   // Convert condition to a bool by comparing equal to 0.0.
940   EndCond = Builder.CreateFCmpONE(
941       EndCond, ConstantFP::get(TheContext, APFloat(0.0)), "loopcond");
942 
943   // Create the "after loop" block and insert it.
944   BasicBlock *AfterBB =
945       BasicBlock::Create(TheContext, "afterloop", TheFunction);
946 
947   // Insert the conditional branch into the end of LoopEndBB.
948   Builder.CreateCondBr(EndCond, LoopBB, AfterBB);
949 
950   // Any new code will be inserted in AfterBB.
951   Builder.SetInsertPoint(AfterBB);
952 
953   // Restore the unshadowed variable.
954   if (OldVal)
955     NamedValues[VarName] = OldVal;
956   else
957     NamedValues.erase(VarName);
958 
959   // for expr always returns 0.0.
960   return Constant::getNullValue(Type::getDoubleTy(TheContext));
961 }
962 
codegen()963 Value *VarExprAST::codegen() {
964   std::vector<AllocaInst *> OldBindings;
965 
966   Function *TheFunction = Builder.GetInsertBlock()->getParent();
967 
968   // Register all variables and emit their initializer.
969   for (unsigned i = 0, e = VarNames.size(); i != e; ++i) {
970     const std::string &VarName = VarNames[i].first;
971     ExprAST *Init = VarNames[i].second.get();
972 
973     // Emit the initializer before adding the variable to scope, this prevents
974     // the initializer from referencing the variable itself, and permits stuff
975     // like this:
976     //  var a = 1 in
977     //    var a = a in ...   # refers to outer 'a'.
978     Value *InitVal;
979     if (Init) {
980       InitVal = Init->codegen();
981       if (!InitVal)
982         return nullptr;
983     } else { // If not specified, use 0.0.
984       InitVal = ConstantFP::get(TheContext, APFloat(0.0));
985     }
986 
987     AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, VarName);
988     Builder.CreateStore(InitVal, Alloca);
989 
990     // Remember the old variable binding so that we can restore the binding when
991     // we unrecurse.
992     OldBindings.push_back(NamedValues[VarName]);
993 
994     // Remember this binding.
995     NamedValues[VarName] = Alloca;
996   }
997 
998   // Codegen the body, now that all vars are in scope.
999   Value *BodyVal = Body->codegen();
1000   if (!BodyVal)
1001     return nullptr;
1002 
1003   // Pop all our variables from scope.
1004   for (unsigned i = 0, e = VarNames.size(); i != e; ++i)
1005     NamedValues[VarNames[i].first] = OldBindings[i];
1006 
1007   // Return the body computation.
1008   return BodyVal;
1009 }
1010 
codegen()1011 Function *PrototypeAST::codegen() {
1012   // Make the function type:  double(double,double) etc.
1013   std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(TheContext));
1014   FunctionType *FT =
1015       FunctionType::get(Type::getDoubleTy(TheContext), Doubles, false);
1016 
1017   Function *F =
1018       Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
1019 
1020   // Set names for all arguments.
1021   unsigned Idx = 0;
1022   for (auto &Arg : F->args())
1023     Arg.setName(Args[Idx++]);
1024 
1025   return F;
1026 }
1027 
getProto() const1028 const PrototypeAST& FunctionAST::getProto() const {
1029   return *Proto;
1030 }
1031 
getName() const1032 const std::string& FunctionAST::getName() const {
1033   return Proto->getName();
1034 }
1035 
codegen()1036 Function *FunctionAST::codegen() {
1037   // Transfer ownership of the prototype to the FunctionProtos map, but keep a
1038   // reference to it for use below.
1039   auto &P = *Proto;
1040   Function *TheFunction = getFunction(P.getName());
1041   if (!TheFunction)
1042     return nullptr;
1043 
1044   // If this is an operator, install it.
1045   if (P.isBinaryOp())
1046     BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
1047 
1048   // Create a new basic block to start insertion into.
1049   BasicBlock *BB = BasicBlock::Create(TheContext, "entry", TheFunction);
1050   Builder.SetInsertPoint(BB);
1051 
1052   // Record the function arguments in the NamedValues map.
1053   NamedValues.clear();
1054   for (auto &Arg : TheFunction->args()) {
1055     // Create an alloca for this variable.
1056     AllocaInst *Alloca = CreateEntryBlockAlloca(TheFunction, Arg.getName());
1057 
1058     // Store the initial value into the alloca.
1059     Builder.CreateStore(&Arg, Alloca);
1060 
1061     // Add arguments to variable symbol table.
1062     NamedValues[Arg.getName()] = Alloca;
1063   }
1064 
1065   if (Value *RetVal = Body->codegen()) {
1066     // Finish off the function.
1067     Builder.CreateRet(RetVal);
1068 
1069     // Validate the generated code, checking for consistency.
1070     verifyFunction(*TheFunction);
1071 
1072     return TheFunction;
1073   }
1074 
1075   // Error reading body, remove function.
1076   TheFunction->eraseFromParent();
1077 
1078   if (P.isBinaryOp())
1079     BinopPrecedence.erase(Proto->getOperatorName());
1080   return nullptr;
1081 }
1082 
1083 //===----------------------------------------------------------------------===//
1084 // Top-Level parsing and JIT Driver
1085 //===----------------------------------------------------------------------===//
1086 
InitializeModule()1087 static void InitializeModule() {
1088   // Open a new module.
1089   TheModule = llvm::make_unique<Module>("my cool jit", TheContext);
1090   TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
1091 }
1092 
1093 std::unique_ptr<llvm::Module>
irgenAndTakeOwnership(FunctionAST & FnAST,const std::string & Suffix)1094 irgenAndTakeOwnership(FunctionAST &FnAST, const std::string &Suffix) {
1095   if (auto *F = FnAST.codegen()) {
1096     F->setName(F->getName() + Suffix);
1097     auto M = std::move(TheModule);
1098     // Start a new module.
1099     InitializeModule();
1100     return M;
1101   } else
1102     report_fatal_error("Couldn't compile lazily JIT'd function");
1103 }
1104 
HandleDefinition()1105 static void HandleDefinition() {
1106   if (auto FnAST = ParseDefinition()) {
1107     FunctionProtos[FnAST->getProto().getName()] =
1108       llvm::make_unique<PrototypeAST>(FnAST->getProto());
1109     ExitOnErr(TheJIT->addFunctionAST(std::move(FnAST)));
1110   } else {
1111     // Skip token for error recovery.
1112     getNextToken();
1113   }
1114 }
1115 
HandleExtern()1116 static void HandleExtern() {
1117   if (auto ProtoAST = ParseExtern()) {
1118     if (auto *FnIR = ProtoAST->codegen()) {
1119       fprintf(stderr, "Read extern: ");
1120       FnIR->dump();
1121       FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
1122     }
1123   } else {
1124     // Skip token for error recovery.
1125     getNextToken();
1126   }
1127 }
1128 
HandleTopLevelExpression()1129 static void HandleTopLevelExpression() {
1130   // Evaluate a top-level expression into an anonymous function.
1131   if (auto FnAST = ParseTopLevelExpr()) {
1132     FunctionProtos[FnAST->getName()] =
1133       llvm::make_unique<PrototypeAST>(FnAST->getProto());
1134     if (FnAST->codegen()) {
1135       // JIT the module containing the anonymous expression, keeping a handle so
1136       // we can free it later.
1137       auto H = TheJIT->addModule(std::move(TheModule));
1138       InitializeModule();
1139 
1140       // Search the JIT for the __anon_expr symbol.
1141       auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
1142       assert(ExprSymbol && "Function not found");
1143 
1144       // Get the symbol's address and cast it to the right type (takes no
1145       // arguments, returns a double) so we can call it as a native function.
1146       double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
1147       fprintf(stderr, "Evaluated to %f\n", FP());
1148 
1149       // Delete the anonymous expression module from the JIT.
1150       TheJIT->removeModule(H);
1151     }
1152   } else {
1153     // Skip token for error recovery.
1154     getNextToken();
1155   }
1156 }
1157 
1158 /// top ::= definition | external | expression | ';'
MainLoop()1159 static void MainLoop() {
1160   while (true) {
1161     fprintf(stderr, "ready> ");
1162     switch (CurTok) {
1163     case tok_eof:
1164       return;
1165     case ';': // ignore top-level semicolons.
1166       getNextToken();
1167       break;
1168     case tok_def:
1169       HandleDefinition();
1170       break;
1171     case tok_extern:
1172       HandleExtern();
1173       break;
1174     default:
1175       HandleTopLevelExpression();
1176       break;
1177     }
1178   }
1179 }
1180 
1181 //===----------------------------------------------------------------------===//
1182 // "Library" functions that can be "extern'd" from user code.
1183 //===----------------------------------------------------------------------===//
1184 
/// putchard - putchar that takes a double and returns 0.
///
/// The double is truncated toward zero and written to stderr as one character
/// (fputc converts the int to unsigned char). Values whose truncation does not
/// fit in a 32-bit integer (NaN, infinities, huge magnitudes) write nothing:
/// casting them straight to an integer type, as `(char)X` did, is undefined
/// behaviour in C++.
extern "C" double putchard(double X) {
  // NaN fails both comparisons, so it is rejected here too.
  if (X >= INT32_MIN && X <= INT32_MAX)
    fputc(static_cast<int>(static_cast<int32_t>(X)), stderr);
  return 0;
}
1190 
/// printd - Print a double to stderr as "%f" followed by a newline.
///
/// Exposed with C linkage so Kaleidoscope code can `extern` it; the result is
/// always 0.
extern "C" double printd(double X) {
  const double Result = 0;
  fprintf(stderr, "%f\n", X);
  return Result;
}
1196 
1197 //===----------------------------------------------------------------------===//
1198 // Main driver code.
1199 //===----------------------------------------------------------------------===//
1200 
main()1201 int main() {
1202   InitializeNativeTarget();
1203   InitializeNativeTargetAsmPrinter();
1204   InitializeNativeTargetAsmParser();
1205 
1206   ExitOnErr.setBanner("Kaleidoscope: ");
1207 
1208   // Install standard binary operators.
1209   // 1 is lowest precedence.
1210   BinopPrecedence['='] = 2;
1211   BinopPrecedence['<'] = 10;
1212   BinopPrecedence['+'] = 20;
1213   BinopPrecedence['-'] = 20;
1214   BinopPrecedence['*'] = 40; // highest.
1215 
1216   // Prime the first token.
1217   fprintf(stderr, "ready> ");
1218   getNextToken();
1219 
1220   TheJIT = llvm::make_unique<KaleidoscopeJIT>();
1221 
1222   InitializeModule();
1223 
1224   // Run the main "interpreter loop" now.
1225   MainLoop();
1226 
1227   return 0;
1228 }
1229