1 #include "llvm/ADT/STLExtras.h"
2 #include "llvm/Analysis/Passes.h"
3 #include "llvm/IR/IRBuilder.h"
4 #include "llvm/IR/LLVMContext.h"
5 #include "llvm/IR/LegacyPassManager.h"
6 #include "llvm/IR/Module.h"
7 #include "llvm/IR/Verifier.h"
8 #include "llvm/Support/TargetSelect.h"
9 #include "llvm/Transforms/Scalar.h"
10 #include <cctype>
11 #include <cstdio>
12 #include <map>
13 #include <string>
14 #include <vector>
15 #include "../include/KaleidoscopeJIT.h"
16
17 using namespace llvm;
18 using namespace llvm::orc;
19
20 //===----------------------------------------------------------------------===//
21 // Lexer
22 //===----------------------------------------------------------------------===//
23
// Token kinds produced by the lexer. A character the lexer has no special
// meaning for is returned as its own code in [0-255]; every kind listed here
// is negative so it can never be confused with such a character.
enum Token {
  // End of input.
  tok_eof = -1,

  // Commands.
  tok_def = -2,
  tok_extern = -3,

  // Primary expressions.
  tok_identifier = -4,
  tok_number = -5,

  // Control flow.
  tok_if = -6,
  tok_then = -7,
  tok_else = -8,
  tok_for = -9,
  tok_in = -10,

  // User-defined operators.
  tok_binary = -11,
  tok_unary = -12,
};
48
49 static std::string IdentifierStr; // Filled in if tok_identifier
50 static double NumVal; // Filled in if tok_number
51
52 /// gettok - Return the next token from standard input.
gettok()53 static int gettok() {
54 static int LastChar = ' ';
55
56 // Skip any whitespace.
57 while (isspace(LastChar))
58 LastChar = getchar();
59
60 if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
61 IdentifierStr = LastChar;
62 while (isalnum((LastChar = getchar())))
63 IdentifierStr += LastChar;
64
65 if (IdentifierStr == "def")
66 return tok_def;
67 if (IdentifierStr == "extern")
68 return tok_extern;
69 if (IdentifierStr == "if")
70 return tok_if;
71 if (IdentifierStr == "then")
72 return tok_then;
73 if (IdentifierStr == "else")
74 return tok_else;
75 if (IdentifierStr == "for")
76 return tok_for;
77 if (IdentifierStr == "in")
78 return tok_in;
79 if (IdentifierStr == "binary")
80 return tok_binary;
81 if (IdentifierStr == "unary")
82 return tok_unary;
83 return tok_identifier;
84 }
85
86 if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
87 std::string NumStr;
88 do {
89 NumStr += LastChar;
90 LastChar = getchar();
91 } while (isdigit(LastChar) || LastChar == '.');
92
93 NumVal = strtod(NumStr.c_str(), nullptr);
94 return tok_number;
95 }
96
97 if (LastChar == '#') {
98 // Comment until end of line.
99 do
100 LastChar = getchar();
101 while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
102
103 if (LastChar != EOF)
104 return gettok();
105 }
106
107 // Check for end of file. Don't eat the EOF.
108 if (LastChar == EOF)
109 return tok_eof;
110
111 // Otherwise, just return the character as its ascii value.
112 int ThisChar = LastChar;
113 LastChar = getchar();
114 return ThisChar;
115 }
116
117 //===----------------------------------------------------------------------===//
118 // Abstract Syntax Tree (aka Parse Tree)
119 //===----------------------------------------------------------------------===//
120 namespace {
121 /// ExprAST - Base class for all expression nodes.
122 class ExprAST {
123 public:
~ExprAST()124 virtual ~ExprAST() {}
125 virtual Value *codegen() = 0;
126 };
127
128 /// NumberExprAST - Expression class for numeric literals like "1.0".
129 class NumberExprAST : public ExprAST {
130 double Val;
131
132 public:
NumberExprAST(double Val)133 NumberExprAST(double Val) : Val(Val) {}
134 Value *codegen() override;
135 };
136
137 /// VariableExprAST - Expression class for referencing a variable, like "a".
138 class VariableExprAST : public ExprAST {
139 std::string Name;
140
141 public:
VariableExprAST(const std::string & Name)142 VariableExprAST(const std::string &Name) : Name(Name) {}
143 Value *codegen() override;
144 };
145
146 /// UnaryExprAST - Expression class for a unary operator.
147 class UnaryExprAST : public ExprAST {
148 char Opcode;
149 std::unique_ptr<ExprAST> Operand;
150
151 public:
UnaryExprAST(char Opcode,std::unique_ptr<ExprAST> Operand)152 UnaryExprAST(char Opcode, std::unique_ptr<ExprAST> Operand)
153 : Opcode(Opcode), Operand(std::move(Operand)) {}
154 Value *codegen() override;
155 };
156
157 /// BinaryExprAST - Expression class for a binary operator.
158 class BinaryExprAST : public ExprAST {
159 char Op;
160 std::unique_ptr<ExprAST> LHS, RHS;
161
162 public:
BinaryExprAST(char Op,std::unique_ptr<ExprAST> LHS,std::unique_ptr<ExprAST> RHS)163 BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
164 std::unique_ptr<ExprAST> RHS)
165 : Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
166 Value *codegen() override;
167 };
168
169 /// CallExprAST - Expression class for function calls.
170 class CallExprAST : public ExprAST {
171 std::string Callee;
172 std::vector<std::unique_ptr<ExprAST>> Args;
173
174 public:
CallExprAST(const std::string & Callee,std::vector<std::unique_ptr<ExprAST>> Args)175 CallExprAST(const std::string &Callee,
176 std::vector<std::unique_ptr<ExprAST>> Args)
177 : Callee(Callee), Args(std::move(Args)) {}
178 Value *codegen() override;
179 };
180
181 /// IfExprAST - Expression class for if/then/else.
182 class IfExprAST : public ExprAST {
183 std::unique_ptr<ExprAST> Cond, Then, Else;
184
185 public:
IfExprAST(std::unique_ptr<ExprAST> Cond,std::unique_ptr<ExprAST> Then,std::unique_ptr<ExprAST> Else)186 IfExprAST(std::unique_ptr<ExprAST> Cond, std::unique_ptr<ExprAST> Then,
187 std::unique_ptr<ExprAST> Else)
188 : Cond(std::move(Cond)), Then(std::move(Then)), Else(std::move(Else)) {}
189 Value *codegen() override;
190 };
191
192 /// ForExprAST - Expression class for for/in.
193 class ForExprAST : public ExprAST {
194 std::string VarName;
195 std::unique_ptr<ExprAST> Start, End, Step, Body;
196
197 public:
ForExprAST(const std::string & VarName,std::unique_ptr<ExprAST> Start,std::unique_ptr<ExprAST> End,std::unique_ptr<ExprAST> Step,std::unique_ptr<ExprAST> Body)198 ForExprAST(const std::string &VarName, std::unique_ptr<ExprAST> Start,
199 std::unique_ptr<ExprAST> End, std::unique_ptr<ExprAST> Step,
200 std::unique_ptr<ExprAST> Body)
201 : VarName(VarName), Start(std::move(Start)), End(std::move(End)),
202 Step(std::move(Step)), Body(std::move(Body)) {}
203 Value *codegen() override;
204 };
205
206 /// PrototypeAST - This class represents the "prototype" for a function,
207 /// which captures its name, and its argument names (thus implicitly the number
208 /// of arguments the function takes), as well as if it is an operator.
209 class PrototypeAST {
210 std::string Name;
211 std::vector<std::string> Args;
212 bool IsOperator;
213 unsigned Precedence; // Precedence if a binary op.
214
215 public:
PrototypeAST(const std::string & Name,std::vector<std::string> Args,bool IsOperator=false,unsigned Prec=0)216 PrototypeAST(const std::string &Name, std::vector<std::string> Args,
217 bool IsOperator = false, unsigned Prec = 0)
218 : Name(Name), Args(std::move(Args)), IsOperator(IsOperator),
219 Precedence(Prec) {}
220 Function *codegen();
getName() const221 const std::string &getName() const { return Name; }
222
isUnaryOp() const223 bool isUnaryOp() const { return IsOperator && Args.size() == 1; }
isBinaryOp() const224 bool isBinaryOp() const { return IsOperator && Args.size() == 2; }
225
getOperatorName() const226 char getOperatorName() const {
227 assert(isUnaryOp() || isBinaryOp());
228 return Name[Name.size() - 1];
229 }
230
getBinaryPrecedence() const231 unsigned getBinaryPrecedence() const { return Precedence; }
232 };
233
234 /// FunctionAST - This class represents a function definition itself.
235 class FunctionAST {
236 std::unique_ptr<PrototypeAST> Proto;
237 std::unique_ptr<ExprAST> Body;
238
239 public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,std::unique_ptr<ExprAST> Body)240 FunctionAST(std::unique_ptr<PrototypeAST> Proto,
241 std::unique_ptr<ExprAST> Body)
242 : Proto(std::move(Proto)), Body(std::move(Body)) {}
243 Function *codegen();
244 };
245 } // end anonymous namespace
246
247 //===----------------------------------------------------------------------===//
248 // Parser
249 //===----------------------------------------------------------------------===//
250
251 /// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current
252 /// token the parser is looking at. getNextToken reads another token from the
253 /// lexer and updates CurTok with its results.
254 static int CurTok;
getNextToken()255 static int getNextToken() { return CurTok = gettok(); }
256
257 /// BinopPrecedence - This holds the precedence for each binary operator that is
258 /// defined.
259 static std::map<char, int> BinopPrecedence;
260
261 /// GetTokPrecedence - Get the precedence of the pending binary operator token.
GetTokPrecedence()262 static int GetTokPrecedence() {
263 if (!isascii(CurTok))
264 return -1;
265
266 // Make sure it's a declared binop.
267 int TokPrec = BinopPrecedence[CurTok];
268 if (TokPrec <= 0)
269 return -1;
270 return TokPrec;
271 }
272
273 /// Error* - These are little helper functions for error handling.
Error(const char * Str)274 std::unique_ptr<ExprAST> Error(const char *Str) {
275 fprintf(stderr, "Error: %s\n", Str);
276 return nullptr;
277 }
278
ErrorP(const char * Str)279 std::unique_ptr<PrototypeAST> ErrorP(const char *Str) {
280 Error(Str);
281 return nullptr;
282 }
283
284 static std::unique_ptr<ExprAST> ParseExpression();
285
286 /// numberexpr ::= number
ParseNumberExpr()287 static std::unique_ptr<ExprAST> ParseNumberExpr() {
288 auto Result = llvm::make_unique<NumberExprAST>(NumVal);
289 getNextToken(); // consume the number
290 return std::move(Result);
291 }
292
293 /// parenexpr ::= '(' expression ')'
ParseParenExpr()294 static std::unique_ptr<ExprAST> ParseParenExpr() {
295 getNextToken(); // eat (.
296 auto V = ParseExpression();
297 if (!V)
298 return nullptr;
299
300 if (CurTok != ')')
301 return Error("expected ')'");
302 getNextToken(); // eat ).
303 return V;
304 }
305
306 /// identifierexpr
307 /// ::= identifier
308 /// ::= identifier '(' expression* ')'
ParseIdentifierExpr()309 static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
310 std::string IdName = IdentifierStr;
311
312 getNextToken(); // eat identifier.
313
314 if (CurTok != '(') // Simple variable ref.
315 return llvm::make_unique<VariableExprAST>(IdName);
316
317 // Call.
318 getNextToken(); // eat (
319 std::vector<std::unique_ptr<ExprAST>> Args;
320 if (CurTok != ')') {
321 while (1) {
322 if (auto Arg = ParseExpression())
323 Args.push_back(std::move(Arg));
324 else
325 return nullptr;
326
327 if (CurTok == ')')
328 break;
329
330 if (CurTok != ',')
331 return Error("Expected ')' or ',' in argument list");
332 getNextToken();
333 }
334 }
335
336 // Eat the ')'.
337 getNextToken();
338
339 return llvm::make_unique<CallExprAST>(IdName, std::move(Args));
340 }
341
342 /// ifexpr ::= 'if' expression 'then' expression 'else' expression
ParseIfExpr()343 static std::unique_ptr<ExprAST> ParseIfExpr() {
344 getNextToken(); // eat the if.
345
346 // condition.
347 auto Cond = ParseExpression();
348 if (!Cond)
349 return nullptr;
350
351 if (CurTok != tok_then)
352 return Error("expected then");
353 getNextToken(); // eat the then
354
355 auto Then = ParseExpression();
356 if (!Then)
357 return nullptr;
358
359 if (CurTok != tok_else)
360 return Error("expected else");
361
362 getNextToken();
363
364 auto Else = ParseExpression();
365 if (!Else)
366 return nullptr;
367
368 return llvm::make_unique<IfExprAST>(std::move(Cond), std::move(Then),
369 std::move(Else));
370 }
371
372 /// forexpr ::= 'for' identifier '=' expr ',' expr (',' expr)? 'in' expression
ParseForExpr()373 static std::unique_ptr<ExprAST> ParseForExpr() {
374 getNextToken(); // eat the for.
375
376 if (CurTok != tok_identifier)
377 return Error("expected identifier after for");
378
379 std::string IdName = IdentifierStr;
380 getNextToken(); // eat identifier.
381
382 if (CurTok != '=')
383 return Error("expected '=' after for");
384 getNextToken(); // eat '='.
385
386 auto Start = ParseExpression();
387 if (!Start)
388 return nullptr;
389 if (CurTok != ',')
390 return Error("expected ',' after for start value");
391 getNextToken();
392
393 auto End = ParseExpression();
394 if (!End)
395 return nullptr;
396
397 // The step value is optional.
398 std::unique_ptr<ExprAST> Step;
399 if (CurTok == ',') {
400 getNextToken();
401 Step = ParseExpression();
402 if (!Step)
403 return nullptr;
404 }
405
406 if (CurTok != tok_in)
407 return Error("expected 'in' after for");
408 getNextToken(); // eat 'in'.
409
410 auto Body = ParseExpression();
411 if (!Body)
412 return nullptr;
413
414 return llvm::make_unique<ForExprAST>(IdName, std::move(Start), std::move(End),
415 std::move(Step), std::move(Body));
416 }
417
418 /// primary
419 /// ::= identifierexpr
420 /// ::= numberexpr
421 /// ::= parenexpr
422 /// ::= ifexpr
423 /// ::= forexpr
ParsePrimary()424 static std::unique_ptr<ExprAST> ParsePrimary() {
425 switch (CurTok) {
426 default:
427 return Error("unknown token when expecting an expression");
428 case tok_identifier:
429 return ParseIdentifierExpr();
430 case tok_number:
431 return ParseNumberExpr();
432 case '(':
433 return ParseParenExpr();
434 case tok_if:
435 return ParseIfExpr();
436 case tok_for:
437 return ParseForExpr();
438 }
439 }
440
441 /// unary
442 /// ::= primary
443 /// ::= '!' unary
ParseUnary()444 static std::unique_ptr<ExprAST> ParseUnary() {
445 // If the current token is not an operator, it must be a primary expr.
446 if (!isascii(CurTok) || CurTok == '(' || CurTok == ',')
447 return ParsePrimary();
448
449 // If this is a unary operator, read it.
450 int Opc = CurTok;
451 getNextToken();
452 if (auto Operand = ParseUnary())
453 return llvm::make_unique<UnaryExprAST>(Opc, std::move(Operand));
454 return nullptr;
455 }
456
457 /// binoprhs
458 /// ::= ('+' unary)*
ParseBinOpRHS(int ExprPrec,std::unique_ptr<ExprAST> LHS)459 static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
460 std::unique_ptr<ExprAST> LHS) {
461 // If this is a binop, find its precedence.
462 while (1) {
463 int TokPrec = GetTokPrecedence();
464
465 // If this is a binop that binds at least as tightly as the current binop,
466 // consume it, otherwise we are done.
467 if (TokPrec < ExprPrec)
468 return LHS;
469
470 // Okay, we know this is a binop.
471 int BinOp = CurTok;
472 getNextToken(); // eat binop
473
474 // Parse the unary expression after the binary operator.
475 auto RHS = ParseUnary();
476 if (!RHS)
477 return nullptr;
478
479 // If BinOp binds less tightly with RHS than the operator after RHS, let
480 // the pending operator take RHS as its LHS.
481 int NextPrec = GetTokPrecedence();
482 if (TokPrec < NextPrec) {
483 RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
484 if (!RHS)
485 return nullptr;
486 }
487
488 // Merge LHS/RHS.
489 LHS =
490 llvm::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
491 }
492 }
493
494 /// expression
495 /// ::= unary binoprhs
496 ///
ParseExpression()497 static std::unique_ptr<ExprAST> ParseExpression() {
498 auto LHS = ParseUnary();
499 if (!LHS)
500 return nullptr;
501
502 return ParseBinOpRHS(0, std::move(LHS));
503 }
504
505 /// prototype
506 /// ::= id '(' id* ')'
507 /// ::= binary LETTER number? (id, id)
508 /// ::= unary LETTER (id)
ParsePrototype()509 static std::unique_ptr<PrototypeAST> ParsePrototype() {
510 std::string FnName;
511
512 unsigned Kind = 0; // 0 = identifier, 1 = unary, 2 = binary.
513 unsigned BinaryPrecedence = 30;
514
515 switch (CurTok) {
516 default:
517 return ErrorP("Expected function name in prototype");
518 case tok_identifier:
519 FnName = IdentifierStr;
520 Kind = 0;
521 getNextToken();
522 break;
523 case tok_unary:
524 getNextToken();
525 if (!isascii(CurTok))
526 return ErrorP("Expected unary operator");
527 FnName = "unary";
528 FnName += (char)CurTok;
529 Kind = 1;
530 getNextToken();
531 break;
532 case tok_binary:
533 getNextToken();
534 if (!isascii(CurTok))
535 return ErrorP("Expected binary operator");
536 FnName = "binary";
537 FnName += (char)CurTok;
538 Kind = 2;
539 getNextToken();
540
541 // Read the precedence if present.
542 if (CurTok == tok_number) {
543 if (NumVal < 1 || NumVal > 100)
544 return ErrorP("Invalid precedecnce: must be 1..100");
545 BinaryPrecedence = (unsigned)NumVal;
546 getNextToken();
547 }
548 break;
549 }
550
551 if (CurTok != '(')
552 return ErrorP("Expected '(' in prototype");
553
554 std::vector<std::string> ArgNames;
555 while (getNextToken() == tok_identifier)
556 ArgNames.push_back(IdentifierStr);
557 if (CurTok != ')')
558 return ErrorP("Expected ')' in prototype");
559
560 // success.
561 getNextToken(); // eat ')'.
562
563 // Verify right number of names for operator.
564 if (Kind && ArgNames.size() != Kind)
565 return ErrorP("Invalid number of operands for operator");
566
567 return llvm::make_unique<PrototypeAST>(FnName, ArgNames, Kind != 0,
568 BinaryPrecedence);
569 }
570
571 /// definition ::= 'def' prototype expression
ParseDefinition()572 static std::unique_ptr<FunctionAST> ParseDefinition() {
573 getNextToken(); // eat def.
574 auto Proto = ParsePrototype();
575 if (!Proto)
576 return nullptr;
577
578 if (auto E = ParseExpression())
579 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
580 return nullptr;
581 }
582
583 /// toplevelexpr ::= expression
ParseTopLevelExpr()584 static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
585 if (auto E = ParseExpression()) {
586 // Make an anonymous proto.
587 auto Proto = llvm::make_unique<PrototypeAST>("__anon_expr",
588 std::vector<std::string>());
589 return llvm::make_unique<FunctionAST>(std::move(Proto), std::move(E));
590 }
591 return nullptr;
592 }
593
594 /// external ::= 'extern' prototype
ParseExtern()595 static std::unique_ptr<PrototypeAST> ParseExtern() {
596 getNextToken(); // eat extern.
597 return ParsePrototype();
598 }
599
600 //===----------------------------------------------------------------------===//
601 // Code Generation
602 //===----------------------------------------------------------------------===//
603
604 static std::unique_ptr<Module> TheModule;
605 static IRBuilder<> Builder(getGlobalContext());
606 static std::map<std::string, Value *> NamedValues;
607 static std::unique_ptr<legacy::FunctionPassManager> TheFPM;
608 static std::unique_ptr<KaleidoscopeJIT> TheJIT;
609 static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
610
ErrorV(const char * Str)611 Value *ErrorV(const char *Str) {
612 Error(Str);
613 return nullptr;
614 }
615
getFunction(std::string Name)616 Function *getFunction(std::string Name) {
617 // First, see if the function has already been added to the current module.
618 if (auto *F = TheModule->getFunction(Name))
619 return F;
620
621 // If not, check whether we can codegen the declaration from some existing
622 // prototype.
623 auto FI = FunctionProtos.find(Name);
624 if (FI != FunctionProtos.end())
625 return FI->second->codegen();
626
627 // If no existing prototype exists, return null.
628 return nullptr;
629 }
630
codegen()631 Value *NumberExprAST::codegen() {
632 return ConstantFP::get(getGlobalContext(), APFloat(Val));
633 }
634
codegen()635 Value *VariableExprAST::codegen() {
636 // Look this variable up in the function.
637 Value *V = NamedValues[Name];
638 if (!V)
639 return ErrorV("Unknown variable name");
640 return V;
641 }
642
codegen()643 Value *UnaryExprAST::codegen() {
644 Value *OperandV = Operand->codegen();
645 if (!OperandV)
646 return nullptr;
647
648 Function *F = getFunction(std::string("unary") + Opcode);
649 if (!F)
650 return ErrorV("Unknown unary operator");
651
652 return Builder.CreateCall(F, OperandV, "unop");
653 }
654
codegen()655 Value *BinaryExprAST::codegen() {
656 Value *L = LHS->codegen();
657 Value *R = RHS->codegen();
658 if (!L || !R)
659 return nullptr;
660
661 switch (Op) {
662 case '+':
663 return Builder.CreateFAdd(L, R, "addtmp");
664 case '-':
665 return Builder.CreateFSub(L, R, "subtmp");
666 case '*':
667 return Builder.CreateFMul(L, R, "multmp");
668 case '<':
669 L = Builder.CreateFCmpULT(L, R, "cmptmp");
670 // Convert bool 0/1 to double 0.0 or 1.0
671 return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()),
672 "booltmp");
673 default:
674 break;
675 }
676
677 // If it wasn't a builtin binary operator, it must be a user defined one. Emit
678 // a call to it.
679 Function *F = getFunction(std::string("binary") + Op);
680 assert(F && "binary operator not found!");
681
682 Value *Ops[] = {L, R};
683 return Builder.CreateCall(F, Ops, "binop");
684 }
685
codegen()686 Value *CallExprAST::codegen() {
687 // Look up the name in the global module table.
688 Function *CalleeF = getFunction(Callee);
689 if (!CalleeF)
690 return ErrorV("Unknown function referenced");
691
692 // If argument mismatch error.
693 if (CalleeF->arg_size() != Args.size())
694 return ErrorV("Incorrect # arguments passed");
695
696 std::vector<Value *> ArgsV;
697 for (unsigned i = 0, e = Args.size(); i != e; ++i) {
698 ArgsV.push_back(Args[i]->codegen());
699 if (!ArgsV.back())
700 return nullptr;
701 }
702
703 return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
704 }
705
codegen()706 Value *IfExprAST::codegen() {
707 Value *CondV = Cond->codegen();
708 if (!CondV)
709 return nullptr;
710
711 // Convert condition to a bool by comparing equal to 0.0.
712 CondV = Builder.CreateFCmpONE(
713 CondV, ConstantFP::get(getGlobalContext(), APFloat(0.0)), "ifcond");
714
715 Function *TheFunction = Builder.GetInsertBlock()->getParent();
716
717 // Create blocks for the then and else cases. Insert the 'then' block at the
718 // end of the function.
719 BasicBlock *ThenBB =
720 BasicBlock::Create(getGlobalContext(), "then", TheFunction);
721 BasicBlock *ElseBB = BasicBlock::Create(getGlobalContext(), "else");
722 BasicBlock *MergeBB = BasicBlock::Create(getGlobalContext(), "ifcont");
723
724 Builder.CreateCondBr(CondV, ThenBB, ElseBB);
725
726 // Emit then value.
727 Builder.SetInsertPoint(ThenBB);
728
729 Value *ThenV = Then->codegen();
730 if (!ThenV)
731 return nullptr;
732
733 Builder.CreateBr(MergeBB);
734 // Codegen of 'Then' can change the current block, update ThenBB for the PHI.
735 ThenBB = Builder.GetInsertBlock();
736
737 // Emit else block.
738 TheFunction->getBasicBlockList().push_back(ElseBB);
739 Builder.SetInsertPoint(ElseBB);
740
741 Value *ElseV = Else->codegen();
742 if (!ElseV)
743 return nullptr;
744
745 Builder.CreateBr(MergeBB);
746 // Codegen of 'Else' can change the current block, update ElseBB for the PHI.
747 ElseBB = Builder.GetInsertBlock();
748
749 // Emit merge block.
750 TheFunction->getBasicBlockList().push_back(MergeBB);
751 Builder.SetInsertPoint(MergeBB);
752 PHINode *PN =
753 Builder.CreatePHI(Type::getDoubleTy(getGlobalContext()), 2, "iftmp");
754
755 PN->addIncoming(ThenV, ThenBB);
756 PN->addIncoming(ElseV, ElseBB);
757 return PN;
758 }
759
760 // Output for-loop as:
761 // ...
762 // start = startexpr
763 // goto loop
764 // loop:
765 // variable = phi [start, loopheader], [nextvariable, loopend]
766 // ...
767 // bodyexpr
768 // ...
769 // loopend:
770 // step = stepexpr
771 // nextvariable = variable + step
772 // endcond = endexpr
773 // br endcond, loop, endloop
774 // outloop:
codegen()775 Value *ForExprAST::codegen() {
776 // Emit the start code first, without 'variable' in scope.
777 Value *StartVal = Start->codegen();
778 if (!StartVal)
779 return nullptr;
780
781 // Make the new basic block for the loop header, inserting after current
782 // block.
783 Function *TheFunction = Builder.GetInsertBlock()->getParent();
784 BasicBlock *PreheaderBB = Builder.GetInsertBlock();
785 BasicBlock *LoopBB =
786 BasicBlock::Create(getGlobalContext(), "loop", TheFunction);
787
788 // Insert an explicit fall through from the current block to the LoopBB.
789 Builder.CreateBr(LoopBB);
790
791 // Start insertion in LoopBB.
792 Builder.SetInsertPoint(LoopBB);
793
794 // Start the PHI node with an entry for Start.
795 PHINode *Variable = Builder.CreatePHI(Type::getDoubleTy(getGlobalContext()),
796 2, VarName.c_str());
797 Variable->addIncoming(StartVal, PreheaderBB);
798
799 // Within the loop, the variable is defined equal to the PHI node. If it
800 // shadows an existing variable, we have to restore it, so save it now.
801 Value *OldVal = NamedValues[VarName];
802 NamedValues[VarName] = Variable;
803
804 // Emit the body of the loop. This, like any other expr, can change the
805 // current BB. Note that we ignore the value computed by the body, but don't
806 // allow an error.
807 if (!Body->codegen())
808 return nullptr;
809
810 // Emit the step value.
811 Value *StepVal = nullptr;
812 if (Step) {
813 StepVal = Step->codegen();
814 if (!StepVal)
815 return nullptr;
816 } else {
817 // If not specified, use 1.0.
818 StepVal = ConstantFP::get(getGlobalContext(), APFloat(1.0));
819 }
820
821 Value *NextVar = Builder.CreateFAdd(Variable, StepVal, "nextvar");
822
823 // Compute the end condition.
824 Value *EndCond = End->codegen();
825 if (!EndCond)
826 return nullptr;
827
828 // Convert condition to a bool by comparing equal to 0.0.
829 EndCond = Builder.CreateFCmpONE(
830 EndCond, ConstantFP::get(getGlobalContext(), APFloat(0.0)), "loopcond");
831
832 // Create the "after loop" block and insert it.
833 BasicBlock *LoopEndBB = Builder.GetInsertBlock();
834 BasicBlock *AfterBB =
835 BasicBlock::Create(getGlobalContext(), "afterloop", TheFunction);
836
837 // Insert the conditional branch into the end of LoopEndBB.
838 Builder.CreateCondBr(EndCond, LoopBB, AfterBB);
839
840 // Any new code will be inserted in AfterBB.
841 Builder.SetInsertPoint(AfterBB);
842
843 // Add a new entry to the PHI node for the backedge.
844 Variable->addIncoming(NextVar, LoopEndBB);
845
846 // Restore the unshadowed variable.
847 if (OldVal)
848 NamedValues[VarName] = OldVal;
849 else
850 NamedValues.erase(VarName);
851
852 // for expr always returns 0.0.
853 return Constant::getNullValue(Type::getDoubleTy(getGlobalContext()));
854 }
855
codegen()856 Function *PrototypeAST::codegen() {
857 // Make the function type: double(double,double) etc.
858 std::vector<Type *> Doubles(Args.size(),
859 Type::getDoubleTy(getGlobalContext()));
860 FunctionType *FT =
861 FunctionType::get(Type::getDoubleTy(getGlobalContext()), Doubles, false);
862
863 Function *F =
864 Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
865
866 // Set names for all arguments.
867 unsigned Idx = 0;
868 for (auto &Arg : F->args())
869 Arg.setName(Args[Idx++]);
870
871 return F;
872 }
873
codegen()874 Function *FunctionAST::codegen() {
875 // Transfer ownership of the prototype to the FunctionProtos map, but keep a
876 // reference to it for use below.
877 auto &P = *Proto;
878 FunctionProtos[Proto->getName()] = std::move(Proto);
879 Function *TheFunction = getFunction(P.getName());
880 if (!TheFunction)
881 return nullptr;
882
883 // If this is an operator, install it.
884 if (P.isBinaryOp())
885 BinopPrecedence[P.getOperatorName()] = P.getBinaryPrecedence();
886
887 // Create a new basic block to start insertion into.
888 BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
889 Builder.SetInsertPoint(BB);
890
891 // Record the function arguments in the NamedValues map.
892 NamedValues.clear();
893 for (auto &Arg : TheFunction->args())
894 NamedValues[Arg.getName()] = &Arg;
895
896 if (Value *RetVal = Body->codegen()) {
897 // Finish off the function.
898 Builder.CreateRet(RetVal);
899
900 // Validate the generated code, checking for consistency.
901 verifyFunction(*TheFunction);
902
903 // Run the optimizer on the function.
904 TheFPM->run(*TheFunction);
905
906 return TheFunction;
907 }
908
909 // Error reading body, remove function.
910 TheFunction->eraseFromParent();
911
912 if (P.isBinaryOp())
913 BinopPrecedence.erase(Proto->getOperatorName());
914 return nullptr;
915 }
916
917 //===----------------------------------------------------------------------===//
918 // Top-Level parsing and JIT Driver
919 //===----------------------------------------------------------------------===//
920
InitializeModuleAndPassManager()921 static void InitializeModuleAndPassManager() {
922 // Open a new module.
923 TheModule = llvm::make_unique<Module>("my cool jit", getGlobalContext());
924 TheModule->setDataLayout(TheJIT->getTargetMachine().createDataLayout());
925
926 // Create a new pass manager attached to it.
927 TheFPM = llvm::make_unique<legacy::FunctionPassManager>(TheModule.get());
928
929 // Do simple "peephole" optimizations and bit-twiddling optzns.
930 TheFPM->add(createInstructionCombiningPass());
931 // Reassociate expressions.
932 TheFPM->add(createReassociatePass());
933 // Eliminate Common SubExpressions.
934 TheFPM->add(createGVNPass());
935 // Simplify the control flow graph (deleting unreachable blocks, etc).
936 TheFPM->add(createCFGSimplificationPass());
937
938 TheFPM->doInitialization();
939 }
940
HandleDefinition()941 static void HandleDefinition() {
942 if (auto FnAST = ParseDefinition()) {
943 if (auto *FnIR = FnAST->codegen()) {
944 fprintf(stderr, "Read function definition:");
945 FnIR->dump();
946 TheJIT->addModule(std::move(TheModule));
947 InitializeModuleAndPassManager();
948 }
949 } else {
950 // Skip token for error recovery.
951 getNextToken();
952 }
953 }
954
HandleExtern()955 static void HandleExtern() {
956 if (auto ProtoAST = ParseExtern()) {
957 if (auto *FnIR = ProtoAST->codegen()) {
958 fprintf(stderr, "Read extern: ");
959 FnIR->dump();
960 FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
961 }
962 } else {
963 // Skip token for error recovery.
964 getNextToken();
965 }
966 }
967
HandleTopLevelExpression()968 static void HandleTopLevelExpression() {
969 // Evaluate a top-level expression into an anonymous function.
970 if (auto FnAST = ParseTopLevelExpr()) {
971 if (FnAST->codegen()) {
972
973 // JIT the module containing the anonymous expression, keeping a handle so
974 // we can free it later.
975 auto H = TheJIT->addModule(std::move(TheModule));
976 InitializeModuleAndPassManager();
977
978 // Search the JIT for the __anon_expr symbol.
979 auto ExprSymbol = TheJIT->findSymbol("__anon_expr");
980 assert(ExprSymbol && "Function not found");
981
982 // Get the symbol's address and cast it to the right type (takes no
983 // arguments, returns a double) so we can call it as a native function.
984 double (*FP)() = (double (*)())(intptr_t)ExprSymbol.getAddress();
985 fprintf(stderr, "Evaluated to %f\n", FP());
986
987 // Delete the anonymous expression module from the JIT.
988 TheJIT->removeModule(H);
989 }
990 } else {
991 // Skip token for error recovery.
992 getNextToken();
993 }
994 }
995
996 /// top ::= definition | external | expression | ';'
MainLoop()997 static void MainLoop() {
998 while (1) {
999 fprintf(stderr, "ready> ");
1000 switch (CurTok) {
1001 case tok_eof:
1002 return;
1003 case ';': // ignore top-level semicolons.
1004 getNextToken();
1005 break;
1006 case tok_def:
1007 HandleDefinition();
1008 break;
1009 case tok_extern:
1010 HandleExtern();
1011 break;
1012 default:
1013 HandleTopLevelExpression();
1014 break;
1015 }
1016 }
1017 }
1018
1019 //===----------------------------------------------------------------------===//
1020 // "Library" functions that can be "extern'd" from user code.
1021 //===----------------------------------------------------------------------===//
1022
/// putchard - putchar that takes a double and returns 0.
/// Writes the character whose code is X, truncated toward zero, to stderr.
/// Values that cannot be a character code are silently ignored: anything
/// outside (-129, 256), which covers both signed and unsigned char, and NaN.
/// Converting such a double directly to char is undefined behavior.
extern "C" double putchard(double X) {
  if (X > -129.0 && X < 256.0)
    fputc(static_cast<int>(X), stderr);
  return 0;
}
1028
/// printd - Write X to stderr in "%f" format followed by a newline. It always
/// returns 0.0 so it can be used inside a Kaleidoscope expression.
extern "C" double printd(double X) {
  std::fprintf(stderr, "%f\n", X);
  return 0.0;
}
1034
1035 //===----------------------------------------------------------------------===//
1036 // Main driver code.
1037 //===----------------------------------------------------------------------===//
1038
main()1039 int main() {
1040 InitializeNativeTarget();
1041 InitializeNativeTargetAsmPrinter();
1042 InitializeNativeTargetAsmParser();
1043
1044 // Install standard binary operators.
1045 // 1 is lowest precedence.
1046 BinopPrecedence['<'] = 10;
1047 BinopPrecedence['+'] = 20;
1048 BinopPrecedence['-'] = 20;
1049 BinopPrecedence['*'] = 40; // highest.
1050
1051 // Prime the first token.
1052 fprintf(stderr, "ready> ");
1053 getNextToken();
1054
1055 TheJIT = llvm::make_unique<KaleidoscopeJIT>();
1056
1057 InitializeModuleAndPassManager();
1058
1059 // Run the main "interpreter loop" now.
1060 MainLoop();
1061
1062 return 0;
1063 }
1064