4. 万花筒:添加 JIT 和优化器支持¶
4.1. 第 4 章 简介¶
欢迎来到“使用 LLVM 实现语言”教程的第 4 章。第 1-3 章描述了一个简单语言的实现,并添加了生成 LLVM IR 的支持。本章介绍两种新技术:为您的语言添加优化器支持,以及添加 JIT 编译器支持。这些新增功能将演示如何为 Kaleidoscope 语言获得良好且高效的代码。
4.2. 简单的常量折叠¶
我们第 3 章的演示优雅且易于扩展。不幸的是,它不会产生非常好的代码。但是,IRBuilder 在编译简单代码时确实为我们提供了明显的优化。
ready> def test(x) 1+2+x;
Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double 3.000000e+00, %x
ret double %addtmp
}
此代码不是解析输入构建的 AST 的逐字转录。那将是
ready> def test(x) 1+2+x;
Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double 2.000000e+00, 1.000000e+00
%addtmp1 = fadd double %addtmp, %x
ret double %addtmp1
}
如上所示,常量折叠尤其是一种非常常见且非常重要的优化:以至于许多语言实现者在其 AST 表示中实现了常量折叠支持。
使用 LLVM,您不需要 AST 中的此支持。由于所有构建 LLVM IR 的调用都通过 LLVM IR 构建器进行,因此构建器本身会在您调用它时检查是否存在常量折叠机会。如果是,它只需执行常量折叠并返回常量,而不是创建指令。
好吧,这很容易 :)。在实践中,我们建议在生成类似这样的代码时始终使用IRBuilder
。它在使用时没有“语法开销”(您不必用常量检查来使您的编译器变得丑陋),并且在某些情况下可以显着减少生成的 LLVM IR 量(特别是对于使用宏预处理器或大量使用常量的语言)。
另一方面,IRBuilder
受到这样一个事实的限制,即它在构建时将其所有分析内联到代码中。如果您采用稍微复杂一点的示例
ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double 3.000000e+00, %x
%addtmp1 = fadd double %x, 3.000000e+00
%multmp = fmul double %addtmp, %addtmp1
ret double %multmp
}
在这种情况下,乘法的 LHS 和 RHS 是相同的值。我们真的希望看到它生成“tmp = x+3; result = tmp*tmp;
”而不是计算两次“x+3
”。
不幸的是,无论进行多少本地分析都无法检测和纠正此问题。这需要两个转换:表达式的重新关联(使加法在词法上相同)和公共子表达式消除 (CSE) 以删除冗余的加法指令。幸运的是,LLVM 提供了广泛的优化,您可以以“Pass”的形式使用它们。
4.3. LLVM 优化 Pass¶
LLVM 提供了许多优化 Pass,它们执行许多不同类型的事情,并且具有不同的权衡。与其他系统不同,LLVM 并不坚持错误的观念,即一套优化适合所有语言和所有情况。LLVM 允许编译器实现者完全决定使用哪些优化,以什么顺序,以及在什么情况下使用。
作为一个具体的例子,LLVM 支持“整个模块”Pass,它们查看尽可能大的代码体(通常是整个文件,但如果在链接时运行,这可能是整个程序的很大一部分)。它还支持并包含“每个函数”Pass,这些 Pass 只在一个函数上运行,而不查看其他函数。有关 Pass 及其运行方式的更多信息,请参阅如何编写 Pass文档和LLVM Pass 列表。
对于 Kaleidoscope,我们目前正在动态生成函数,一次一个,当用户输入它们时。在这种情况下,我们并没有追求极致的优化体验,但我们也希望尽可能地捕捉到简单快速的内容。因此,我们选择在用户输入函数时运行一些每个函数的优化。如果我们想制作一个“静态 Kaleidoscope 编译器”,我们将使用我们现在拥有的完全相同的代码,只是我们将推迟运行优化器,直到整个文件都被解析。
除了函数和模块 Pass 之间的区别外,Pass 可以分为转换 Pass 和分析 Pass。转换 Pass 会修改 IR,而分析 Pass 会计算其他 Pass 可以使用的信息。为了添加转换 Pass,它依赖的所有分析 Pass 必须提前注册。
为了让每个函数的优化开始运行,我们需要设置一个FunctionPassManager来保存和组织我们想要运行的 LLVM 优化。一旦我们有了它,我们就可以添加一组要运行的优化。我们需要为我们要优化的每个模块创建一个新的 FunctionPassManager,因此我们将添加到上一章中创建的一个函数中(InitializeModule()
)
void InitializeModuleAndManagers(void) {
// Open a new context and module.
TheContext = std::make_unique<LLVMContext>();
TheModule = std::make_unique<Module>("KaleidoscopeJIT", *TheContext);
TheModule->setDataLayout(TheJIT->getDataLayout());
// Create a new builder for the module.
Builder = std::make_unique<IRBuilder<>>(*TheContext);
// Create new pass and analysis managers.
TheFPM = std::make_unique<FunctionPassManager>();
TheLAM = std::make_unique<LoopAnalysisManager>();
TheFAM = std::make_unique<FunctionAnalysisManager>();
TheCGAM = std::make_unique<CGSCCAnalysisManager>();
TheMAM = std::make_unique<ModuleAnalysisManager>();
ThePIC = std::make_unique<PassInstrumentationCallbacks>();
TheSI = std::make_unique<StandardInstrumentations>(*TheContext,
/*DebugLogging*/ true);
TheSI->registerCallbacks(*ThePIC, TheMAM.get());
...
在初始化全局模块TheModule
和 FunctionPassManager 后,我们需要初始化框架的其他部分。四个 AnalysisManagers 允许我们添加跨越 IR 层次结构四个级别的分析 Pass。PassInstrumentationCallbacks 和 StandardInstrumentations 是 Pass 检测框架所必需的,该框架允许开发人员自定义 Pass 之间发生的事情。
设置好这些管理器后,我们使用一系列“addPass”调用来添加一堆 LLVM 转换 Pass
// Add transform passes.
// Do simple "peephole" optimizations and bit-twiddling optzns.
TheFPM->addPass(InstCombinePass());
// Reassociate expressions.
TheFPM->addPass(ReassociatePass());
// Eliminate Common SubExpressions.
TheFPM->addPass(GVNPass());
// Simplify the control flow graph (deleting unreachable blocks, etc).
TheFPM->addPass(SimplifyCFGPass());
在这种情况下,我们选择添加四个优化 Pass。我们在这里选择的 Pass 是一组相当标准的“清理”优化,对各种代码很有用。我不会深入探讨它们的作用,但相信我,它们是一个很好的起点 :)。
接下来,我们注册转换 Pass 使用的分析 Pass。
// Register analysis passes used in these transform passes.
PassBuilder PB;
PB.registerModuleAnalyses(*TheMAM);
PB.registerFunctionAnalyses(*TheFAM);
PB.crossRegisterProxies(*TheLAM, *TheFAM, *TheCGAM, *TheMAM);
}
设置好 PassManager 后,我们需要使用它。我们通过在我们的新创建的函数构建后(在FunctionAST::codegen()
中)但在将其返回给客户端之前运行它来做到这一点
if (Value *RetVal = Body->codegen()) {
// Finish off the function.
Builder.CreateRet(RetVal);
// Validate the generated code, checking for consistency.
verifyFunction(*TheFunction);
// Optimize the function.
TheFPM->run(*TheFunction, *TheFAM);
return TheFunction;
}
如您所见,这非常简单。FunctionPassManager
会就地优化和更新 LLVM Function*,从而改进(希望如此)其主体。有了这个,我们可以再次尝试上面的测试
ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double %x, 3.000000e+00
%multmp = fmul double %addtmp, %addtmp
ret double %multmp
}
正如预期的那样,我们现在得到了经过良好优化的代码,在每次执行此函数时都节省了一条浮点加法指令。
LLVM 提供了各种可以在特定情况下使用的优化。一些有关各种 Pass 的文档可用,但并不完整。另一个很好的想法来源可以来自查看Clang
运行的 Pass 以开始使用。“opt
”工具允许您从命令行试验 Pass,因此您可以查看它们是否执行任何操作。
现在我们已经从前端获得了合理的代码,让我们谈谈如何执行它!
4.4. 添加 JIT 编译器¶
LLVM IR 中可用的代码可以应用各种工具。例如,您可以对其运行优化(如上所述),可以以文本或二进制形式将其转储,可以将代码编译为某个目标的汇编文件 (.s),或者可以 JIT 编译它。LLVM IR 表示的优点在于它是编译器许多不同部分之间的“通用货币”。
在本节中,我们将向我们的解释器添加 JIT 编译器支持。我们希望 Kaleidoscope 的基本思想是让用户像现在一样输入函数体,但立即评估他们输入的顶级表达式。例如,如果他们输入“1 + 2;”,我们应该评估并打印出 3。如果他们定义了一个函数,他们应该能够从命令行调用它。
为了做到这一点,我们首先准备环境以创建当前本机目标的代码并声明和初始化 JIT。这是通过调用一些InitializeNativeTarget\*
函数并添加一个全局变量TheJIT
以及在main
中初始化它来完成的
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
...
int main() {
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
InitializeNativeTargetAsmParser();
// Install standard binary operators.
// 1 is lowest precedence.
BinopPrecedence['<'] = 10;
BinopPrecedence['+'] = 20;
BinopPrecedence['-'] = 20;
BinopPrecedence['*'] = 40; // highest.
// Prime the first token.
fprintf(stderr, "ready> ");
getNextToken();
TheJIT = std::make_unique<KaleidoscopeJIT>();
// Run the main "interpreter loop" now.
MainLoop();
return 0;
}
我们还需要为 JIT 设置数据布局
void InitializeModuleAndPassManager(void) {
// Open a new context and module.
TheContext = std::make_unique<LLVMContext>();
TheModule = std::make_unique<Module>("my cool jit", TheContext);
TheModule->setDataLayout(TheJIT->getDataLayout());
// Create a new builder for the module.
Builder = std::make_unique<IRBuilder<>>(*TheContext);
// Create a new pass manager attached to it.
TheFPM = std::make_unique<legacy::FunctionPassManager>(TheModule.get());
...
KaleidoscopeJIT 类是一个专门为这些教程构建的简单 JIT,可在 LLVM 源代码中找到,位于llvm-src/examples/Kaleidoscope/include/KaleidoscopeJIT.h。在后面的章节中,我们将研究它的工作原理并为其添加新功能,但现在我们将认为它是给定的。它的 API 非常简单:addModule
将 LLVM IR 模块添加到 JIT 中,使其函数可供执行(其内存由ResourceTracker
管理);并且lookup
允许我们查找已编译代码的指针。
我们可以利用这个简单的 API 并更改解析顶级表达式的代码,使其如下所示
static ExitOnError ExitOnErr;
...
static void HandleTopLevelExpression() {
// Evaluate a top-level expression into an anonymous function.
if (auto FnAST = ParseTopLevelExpr()) {
if (FnAST->codegen()) {
// Create a ResourceTracker to track JIT'd memory allocated to our
// anonymous expression -- that way we can free it after executing.
auto RT = TheJIT->getMainJITDylib().createResourceTracker();
auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
InitializeModuleAndPassManager();
// Search the JIT for the __anon_expr symbol.
auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr"));
assert(ExprSymbol && "Function not found");
// Get the symbol's address and cast it to the right type (takes no
// arguments, returns a double) so we can call it as a native function.
double (*FP)() = ExprSymbol.getAddress().toPtr<double (*)()>();
fprintf(stderr, "Evaluated to %f\n", FP());
// Delete the anonymous expression module from the JIT.
ExitOnErr(RT->remove());
}
如果解析和代码生成成功,下一步就是将包含顶层表达式的模块添加到 JIT 中。我们通过调用 addModule 来实现这一点,它会触发模块中所有函数的代码生成,并接收一个 ResourceTracker
,稍后可用于从 JIT 中移除该模块。一旦模块被添加到 JIT 中,它就无法再修改,因此我们还通过调用 InitializeModuleAndPassManager()
打开一个新模块来保存后续代码。
将模块添加到 JIT 后,我们需要获取指向最终生成代码的指针。我们通过调用 JIT 的 lookup
方法并传递顶层表达式函数的名称来实现:__anon_expr
。由于我们刚刚添加了此函数,因此我们断言 lookup
返回了一个结果。
接下来,我们通过调用符号上的 getAddress()
获取 __anon_expr
函数的内存地址。回想一下,我们将顶层表达式编译成一个自包含的 LLVM 函数,该函数不接受任何参数并返回计算出的双精度数。由于 LLVM JIT 编译器与本机平台 ABI 匹配,这意味着您可以直接将结果指针转换为该类型函数的指针并直接调用它。这意味着 JIT 编译的代码与静态链接到应用程序中的本机机器代码之间没有区别。
最后,由于我们不支持重新评估顶层表达式,因此在完成时我们将模块从 JIT 中移除以释放关联的内存。但是,请记住,我们几行前创建的模块(通过 InitializeModuleAndPassManager
)仍然处于打开状态,等待添加新代码。
仅通过这两处修改,让我们看看 Kaleidoscope 现在如何工作!
ready> 4+5;
Read top-level expression:
define double @0() {
entry:
ret double 9.000000e+00
}
Evaluated to 9.000000
看起来它基本上可以工作了。函数的转储显示了我们为每个输入的顶层表达式合成的“不带参数且始终返回双精度数的函数”。这演示了非常基本的功能,但我们能做更多吗?
ready> def testfunc(x y) x + y*2;
Read function definition:
define double @testfunc(double %x, double %y) {
entry:
%multmp = fmul double %y, 2.000000e+00
%addtmp = fadd double %multmp, %x
ret double %addtmp
}
ready> testfunc(4, 10);
Read top-level expression:
define double @1() {
entry:
%calltmp = call double @testfunc(double 4.000000e+00, double 1.000000e+01)
ret double %calltmp
}
Evaluated to 24.000000
ready> testfunc(5, 10);
ready> LLVM ERROR: Program used external function 'testfunc' which could not be resolved!
函数定义和调用也可以工作,但在最后一行出现了一些严重错误。调用看起来有效,那么发生了什么?正如您可能从 API 中猜到的那样,模块是 JIT 的分配单元,而 testfunc 是与匿名表达式位于同一模块的一部分。当我们从 JIT 中移除该模块以释放匿名表达式的内存时,我们也删除了 testfunc
的定义。然后,当我们第二次尝试调用 testfunc 时,JIT 无法再找到它。
解决此问题的最简单方法是将匿名表达式放在与其余函数定义不同的模块中。只要每个被调用的函数都有原型并在调用之前添加到 JIT 中,JIT 就会很乐意跨模块边界解析函数调用。通过将匿名表达式放在不同的模块中,我们可以删除它而不会影响其余函数。
事实上,我们将更进一步,并将每个函数放在其自己的模块中。这样做可以利用 KaleidoscopeJIT 的一个有用特性,这将使我们的环境更像 REPL:函数可以多次添加到 JIT 中(与每个函数必须具有唯一定义的模块不同)。当您在 KaleidoscopeJIT 中查找符号时,它将始终返回最新的定义。
ready> def foo(x) x + 1;
Read function definition:
define double @foo(double %x) {
entry:
%addtmp = fadd double %x, 1.000000e+00
ret double %addtmp
}
ready> foo(2);
Evaluated to 3.000000
ready> def foo(x) x + 2;
define double @foo(double %x) {
entry:
%addtmp = fadd double %x, 2.000000e+00
ret double %addtmp
}
ready> foo(2);
Evaluated to 4.000000
为了允许每个函数在其自己的模块中生存,我们需要一种方法将先前的函数声明重新生成到我们打开的每个新模块中。
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
...
Function *getFunction(std::string Name) {
// First, see if the function has already been added to the current module.
if (auto *F = TheModule->getFunction(Name))
return F;
// If not, check whether we can codegen the declaration from some existing
// prototype.
auto FI = FunctionProtos.find(Name);
if (FI != FunctionProtos.end())
return FI->second->codegen();
// If no existing prototype exists, return null.
return nullptr;
}
...
Value *CallExprAST::codegen() {
// Look up the name in the global module table.
Function *CalleeF = getFunction(Callee);
...
Function *FunctionAST::codegen() {
// Transfer ownership of the prototype to the FunctionProtos map, but keep a
// reference to it for use below.
auto &P = *Proto;
FunctionProtos[Proto->getName()] = std::move(Proto);
Function *TheFunction = getFunction(P.getName());
if (!TheFunction)
return nullptr;
为了实现这一点,我们将首先添加一个新的全局变量 FunctionProtos
,它保存每个函数的最新原型。我们还将添加一个便捷方法 getFunction()
来替换对 TheModule->getFunction()
的调用。我们的便捷方法在 TheModule
中搜索现有的函数声明,如果找不到,则回退到从 FunctionProtos 生成新的声明。在 CallExprAST::codegen()
中,我们只需要替换对 TheModule->getFunction()
的调用。在 FunctionAST::codegen()
中,我们需要先更新 FunctionProtos 映射,然后调用 getFunction()
。完成此操作后,我们始终可以在当前模块中获取任何先前声明函数的函数声明。
我们还需要更新 HandleDefinition 和 HandleExtern。
static void HandleDefinition() {
if (auto FnAST = ParseDefinition()) {
if (auto *FnIR = FnAST->codegen()) {
fprintf(stderr, "Read function definition:");
FnIR->print(errs());
fprintf(stderr, "\n");
ExitOnErr(TheJIT->addModule(
ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
InitializeModuleAndPassManager();
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
static void HandleExtern() {
if (auto ProtoAST = ParseExtern()) {
if (auto *FnIR = ProtoAST->codegen()) {
fprintf(stderr, "Read extern: ");
FnIR->print(errs());
fprintf(stderr, "\n");
FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
在 HandleDefinition 中,我们添加了两行代码以将新定义的函数传输到 JIT 并打开一个新模块。在 HandleExtern 中,我们只需要添加一行代码以将原型添加到 FunctionProtos。
警告
不允许在单独的模块中重复符号,因为 LLVM-9 中是这样规定的。这意味着您不能像下面显示的那样在 Kaleidoscope 中重新定义函数。只需跳过此部分。
原因是较新的 OrcV2 JIT API 试图非常接近静态和动态链接器规则,包括拒绝重复符号。要求符号名称唯一,使我们能够使用(唯一)符号名称作为键来支持符号的并发编译,以便进行跟踪。
完成这些更改后,让我们再次尝试我们的 REPL(这次我删除了匿名函数的转储,您应该已经了解了)。
ready> def foo(x) x + 1;
ready> foo(2);
Evaluated to 3.000000
ready> def foo(x) x + 2;
ready> foo(2);
Evaluated to 4.000000
它可以工作了!
即使使用这段简单的代码,我们也获得了一些令人惊讶的强大功能 - 请查看。
ready> extern sin(x);
Read extern:
declare double @sin(double)
ready> extern cos(x);
Read extern:
declare double @cos(double)
ready> sin(1.0);
Read top-level expression:
define double @2() {
entry:
ret double 0x3FEAED548F090CEE
}
Evaluated to 0.841471
ready> def foo(x) sin(x)*sin(x) + cos(x)*cos(x);
Read function definition:
define double @foo(double %x) {
entry:
%calltmp = call double @sin(double %x)
%multmp = fmul double %calltmp, %calltmp
%calltmp2 = call double @cos(double %x)
%multmp4 = fmul double %calltmp2, %calltmp2
%addtmp = fadd double %multmp, %multmp4
ret double %addtmp
}
ready> foo(4.0);
Read top-level expression:
define double @3() {
entry:
%calltmp = call double @foo(double 4.000000e+00)
ret double %calltmp
}
Evaluated to 1.000000
哇,JIT 如何知道 sin 和 cos?答案出奇地简单:KaleidoscopeJIT 有一个简单的符号解析规则,它用于查找任何给定模块中不可用的符号:首先,它搜索所有已添加到 JIT 的模块,从最新的到最旧的,以查找最新的定义。如果在 JIT 内部找不到定义,它将回退到在 Kaleidoscope 进程本身调用“dlsym("sin")
”。由于“sin
”在 JIT 的地址空间中定义,因此它只需修补模块中的调用以直接调用 libm 版本的 sin
。但在某些情况下,这甚至会更进一步:由于 sin 和 cos 是标准数学函数的名称,因此当使用常量(如上面的“sin(1.0)
”)调用时,常量折叠器会直接将函数调用评估为正确的结果。
将来,我们将了解如何调整此符号解析规则可用于启用各种有用的功能,从安全性(限制 JIT 代码可用的符号集)到基于符号名称的动态代码生成,甚至延迟编译。
符号解析规则的一个直接好处是,我们现在可以通过编写任意 C++ 代码来实现操作来扩展语言。例如,如果我们添加
#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif
/// putchard - putchar that takes a double and returns 0.
extern "C" DLLEXPORT double putchard(double X) {
fputc((char)X, stderr);
return 0;
}
请注意,对于 Windows,我们需要实际导出这些函数,因为动态符号加载器将使用 GetProcAddress
来查找符号。
现在,我们可以通过使用诸如“extern putchard(x); putchard(120);
”之类的东西生成简单的控制台输出,它会在控制台上打印小写字母“x”(120 是“x”的 ASCII 代码)。类似的代码可用于在 Kaleidoscope 中实现文件 I/O、控制台输入和许多其他功能。
这完成了 Kaleidoscope 教程的 JIT 和优化器章节。此时,我们可以编译一种非图灵完备的编程语言,并以用户驱动的方式对其进行优化和 JIT 编译。接下来,我们将探讨 使用控制流构造扩展语言,同时解决一些有趣的 LLVM IR 问题。
4.5. 完整代码列表¶
以下是我们正在运行的示例的完整代码列表,并增强了 LLVM JIT 和优化器。要构建此示例,请使用
# Compile
clang++ -g toy.cpp `llvm-config --cxxflags --ldflags --system-libs --libs core orcjit native` -O3 -o toy
# Run
./toy
如果您在 Linux 上编译此代码,请确保也添加“-rdynamic”选项。这确保了在运行时正确解析外部函数。
以下是代码
#include "../include/KaleidoscopeJIT.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Scalar/Reassociate.h"
#include "llvm/Transforms/Scalar/SimplifyCFG.h"
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <vector>
using namespace llvm;
using namespace llvm::orc;
//===----------------------------------------------------------------------===//
// Lexer
//===----------------------------------------------------------------------===//
// The lexer returns tokens [0-255] if it is an unknown character, otherwise one
// of these for known things.
enum Token {
tok_eof = -1,
// commands
tok_def = -2,
tok_extern = -3,
// primary
tok_identifier = -4,
tok_number = -5
};
static std::string IdentifierStr; // Filled in if tok_identifier
static double NumVal; // Filled in if tok_number
/// gettok - Return the next token from standard input.
static int gettok() {
static int LastChar = ' ';
// Skip any whitespace.
while (isspace(LastChar))
LastChar = getchar();
if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
IdentifierStr = LastChar;
while (isalnum((LastChar = getchar())))
IdentifierStr += LastChar;
if (IdentifierStr == "def")
return tok_def;
if (IdentifierStr == "extern")
return tok_extern;
return tok_identifier;
}
if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
std::string NumStr;
do {
NumStr += LastChar;
LastChar = getchar();
} while (isdigit(LastChar) || LastChar == '.');
NumVal = strtod(NumStr.c_str(), nullptr);
return tok_number;
}
if (LastChar == '#') {
// Comment until end of line.
do
LastChar = getchar();
while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
if (LastChar != EOF)
return gettok();
}
// Check for end of file. Don't eat the EOF.
if (LastChar == EOF)
return tok_eof;
// Otherwise, just return the character as its ascii value.
int ThisChar = LastChar;
LastChar = getchar();
return ThisChar;
}
//===----------------------------------------------------------------------===//
// Abstract Syntax Tree (aka Parse Tree)
//===----------------------------------------------------------------------===//
namespace {
/// ExprAST - Base class for all expression nodes.
class ExprAST {
public:
virtual ~ExprAST() = default;
virtual Value *codegen() = 0;
};
/// NumberExprAST - Expression class for numeric literals like "1.0".
class NumberExprAST : public ExprAST {
double Val;
public:
NumberExprAST(double Val) : Val(Val) {}
Value *codegen() override;
};
/// VariableExprAST - Expression class for referencing a variable, like "a".
class VariableExprAST : public ExprAST {
std::string Name;
public:
VariableExprAST(const std::string &Name) : Name(Name) {}
Value *codegen() override;
};
/// BinaryExprAST - Expression class for a binary operator.
class BinaryExprAST : public ExprAST {
char Op;
std::unique_ptr<ExprAST> LHS, RHS;
public:
BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
std::unique_ptr<ExprAST> RHS)
: Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
Value *codegen() override;
};
/// CallExprAST - Expression class for function calls.
class CallExprAST : public ExprAST {
std::string Callee;
std::vector<std::unique_ptr<ExprAST>> Args;
public:
CallExprAST(const std::string &Callee,
std::vector<std::unique_ptr<ExprAST>> Args)
: Callee(Callee), Args(std::move(Args)) {}
Value *codegen() override;
};
/// PrototypeAST - This class represents the "prototype" for a function,
/// which captures its name, and its argument names (thus implicitly the number
/// of arguments the function takes).
class PrototypeAST {
std::string Name;
std::vector<std::string> Args;
public:
PrototypeAST(const std::string &Name, std::vector<std::string> Args)
: Name(Name), Args(std::move(Args)) {}
Function *codegen();
const std::string &getName() const { return Name; }
};
/// FunctionAST - This class represents a function definition itself.
class FunctionAST {
std::unique_ptr<PrototypeAST> Proto;
std::unique_ptr<ExprAST> Body;
public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,
std::unique_ptr<ExprAST> Body)
: Proto(std::move(Proto)), Body(std::move(Body)) {}
Function *codegen();
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//
/// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current
/// token the parser is looking at. getNextToken reads another token from the
/// lexer and updates CurTok with its results.
static int CurTok;
static int getNextToken() { return CurTok = gettok(); }
/// BinopPrecedence - This holds the precedence for each binary operator that is
/// defined.
static std::map<char, int> BinopPrecedence;
/// GetTokPrecedence - Get the precedence of the pending binary operator token.
static int GetTokPrecedence() {
if (!isascii(CurTok))
return -1;
// Make sure it's a declared binop.
int TokPrec = BinopPrecedence[CurTok];
if (TokPrec <= 0)
return -1;
return TokPrec;
}
/// LogError* - These are little helper functions for error handling.
std::unique_ptr<ExprAST> LogError(const char *Str) {
fprintf(stderr, "Error: %s\n", Str);
return nullptr;
}
std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
LogError(Str);
return nullptr;
}
static std::unique_ptr<ExprAST> ParseExpression();
/// numberexpr ::= number
static std::unique_ptr<ExprAST> ParseNumberExpr() {
auto Result = std::make_unique<NumberExprAST>(NumVal);
getNextToken(); // consume the number
return std::move(Result);
}
/// parenexpr ::= '(' expression ')'
static std::unique_ptr<ExprAST> ParseParenExpr() {
getNextToken(); // eat (.
auto V = ParseExpression();
if (!V)
return nullptr;
if (CurTok != ')')
return LogError("expected ')'");
getNextToken(); // eat ).
return V;
}
/// identifierexpr
/// ::= identifier
/// ::= identifier '(' expression* ')'
static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
std::string IdName = IdentifierStr;
getNextToken(); // eat identifier.
if (CurTok != '(') // Simple variable ref.
return std::make_unique<VariableExprAST>(IdName);
// Call.
getNextToken(); // eat (
std::vector<std::unique_ptr<ExprAST>> Args;
if (CurTok != ')') {
while (true) {
if (auto Arg = ParseExpression())
Args.push_back(std::move(Arg));
else
return nullptr;
if (CurTok == ')')
break;
if (CurTok != ',')
return LogError("Expected ')' or ',' in argument list");
getNextToken();
}
}
// Eat the ')'.
getNextToken();
return std::make_unique<CallExprAST>(IdName, std::move(Args));
}
/// primary
/// ::= identifierexpr
/// ::= numberexpr
/// ::= parenexpr
static std::unique_ptr<ExprAST> ParsePrimary() {
switch (CurTok) {
default:
return LogError("unknown token when expecting an expression");
case tok_identifier:
return ParseIdentifierExpr();
case tok_number:
return ParseNumberExpr();
case '(':
return ParseParenExpr();
}
}
/// binoprhs
/// ::= ('+' primary)*
static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
std::unique_ptr<ExprAST> LHS) {
// If this is a binop, find its precedence.
while (true) {
int TokPrec = GetTokPrecedence();
// If this is a binop that binds at least as tightly as the current binop,
// consume it, otherwise we are done.
if (TokPrec < ExprPrec)
return LHS;
// Okay, we know this is a binop.
int BinOp = CurTok;
getNextToken(); // eat binop
// Parse the primary expression after the binary operator.
auto RHS = ParsePrimary();
if (!RHS)
return nullptr;
// If BinOp binds less tightly with RHS than the operator after RHS, let
// the pending operator take RHS as its LHS.
int NextPrec = GetTokPrecedence();
if (TokPrec < NextPrec) {
RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
if (!RHS)
return nullptr;
}
// Merge LHS/RHS.
LHS =
std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
}
}
/// expression
/// ::= primary binoprhs
///
static std::unique_ptr<ExprAST> ParseExpression() {
auto LHS = ParsePrimary();
if (!LHS)
return nullptr;
return ParseBinOpRHS(0, std::move(LHS));
}
/// prototype
/// ::= id '(' id* ')'
static std::unique_ptr<PrototypeAST> ParsePrototype() {
if (CurTok != tok_identifier)
return LogErrorP("Expected function name in prototype");
std::string FnName = IdentifierStr;
getNextToken();
if (CurTok != '(')
return LogErrorP("Expected '(' in prototype");
std::vector<std::string> ArgNames;
while (getNextToken() == tok_identifier)
ArgNames.push_back(IdentifierStr);
if (CurTok != ')')
return LogErrorP("Expected ')' in prototype");
// success.
getNextToken(); // eat ')'.
return std::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
}
/// definition ::= 'def' prototype expression
static std::unique_ptr<FunctionAST> ParseDefinition() {
getNextToken(); // eat def.
auto Proto = ParsePrototype();
if (!Proto)
return nullptr;
if (auto E = ParseExpression())
return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
return nullptr;
}
/// toplevelexpr ::= expression
static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
if (auto E = ParseExpression()) {
// Make an anonymous proto.
auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
std::vector<std::string>());
return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
}
return nullptr;
}
/// external ::= 'extern' prototype
static std::unique_ptr<PrototypeAST> ParseExtern() {
getNextToken(); // eat extern.
return ParsePrototype();
}
//===----------------------------------------------------------------------===//
// Code Generation
//===----------------------------------------------------------------------===//
static std::unique_ptr<LLVMContext> TheContext;
static std::unique_ptr<Module> TheModule;
static std::unique_ptr<IRBuilder<>> Builder;
static std::map<std::string, Value *> NamedValues;
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
static std::unique_ptr<FunctionPassManager> TheFPM;
static std::unique_ptr<LoopAnalysisManager> TheLAM;
static std::unique_ptr<FunctionAnalysisManager> TheFAM;
static std::unique_ptr<CGSCCAnalysisManager> TheCGAM;
static std::unique_ptr<ModuleAnalysisManager> TheMAM;
static std::unique_ptr<PassInstrumentationCallbacks> ThePIC;
static std::unique_ptr<StandardInstrumentations> TheSI;
static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
static ExitOnError ExitOnErr;
Value *LogErrorV(const char *Str) {
LogError(Str);
return nullptr;
}
Function *getFunction(std::string Name) {
// First, see if the function has already been added to the current module.
if (auto *F = TheModule->getFunction(Name))
return F;
// If not, check whether we can codegen the declaration from some existing
// prototype.
auto FI = FunctionProtos.find(Name);
if (FI != FunctionProtos.end())
return FI->second->codegen();
// If no existing prototype exists, return null.
return nullptr;
}
Value *NumberExprAST::codegen() {
return ConstantFP::get(*TheContext, APFloat(Val));
}
Value *VariableExprAST::codegen() {
// Look this variable up in the function.
Value *V = NamedValues[Name];
if (!V)
return LogErrorV("Unknown variable name");
return V;
}
Value *BinaryExprAST::codegen() {
Value *L = LHS->codegen();
Value *R = RHS->codegen();
if (!L || !R)
return nullptr;
switch (Op) {
case '+':
return Builder->CreateFAdd(L, R, "addtmp");
case '-':
return Builder->CreateFSub(L, R, "subtmp");
case '*':
return Builder->CreateFMul(L, R, "multmp");
case '<':
L = Builder->CreateFCmpULT(L, R, "cmptmp");
// Convert bool 0/1 to double 0.0 or 1.0
return Builder->CreateUIToFP(L, Type::getDoubleTy(*TheContext), "booltmp");
default:
return LogErrorV("invalid binary operator");
}
}
Value *CallExprAST::codegen() {
// Look up the name in the global module table.
Function *CalleeF = getFunction(Callee);
if (!CalleeF)
return LogErrorV("Unknown function referenced");
// If argument mismatch error.
if (CalleeF->arg_size() != Args.size())
return LogErrorV("Incorrect # arguments passed");
std::vector<Value *> ArgsV;
for (unsigned i = 0, e = Args.size(); i != e; ++i) {
ArgsV.push_back(Args[i]->codegen());
if (!ArgsV.back())
return nullptr;
}
return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
}
Function *PrototypeAST::codegen() {
// Make the function type: double(double,double) etc.
std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(*TheContext));
FunctionType *FT =
FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);
Function *F =
Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
// Set names for all arguments.
unsigned Idx = 0;
for (auto &Arg : F->args())
Arg.setName(Args[Idx++]);
return F;
}
Function *FunctionAST::codegen() {
// Transfer ownership of the prototype to the FunctionProtos map, but keep a
// reference to it for use below.
auto &P = *Proto;
FunctionProtos[Proto->getName()] = std::move(Proto);
Function *TheFunction = getFunction(P.getName());
if (!TheFunction)
return nullptr;
// Create a new basic block to start insertion into.
BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
Builder->SetInsertPoint(BB);
// Record the function arguments in the NamedValues map.
NamedValues.clear();
for (auto &Arg : TheFunction->args())
NamedValues[std::string(Arg.getName())] = &Arg;
if (Value *RetVal = Body->codegen()) {
// Finish off the function.
Builder->CreateRet(RetVal);
// Validate the generated code, checking for consistency.
verifyFunction(*TheFunction);
// Run the optimizer on the function.
TheFPM->run(*TheFunction, *TheFAM);
return TheFunction;
}
// Error reading body, remove function.
TheFunction->eraseFromParent();
return nullptr;
}
//===----------------------------------------------------------------------===//
// Top-Level parsing and JIT Driver
//===----------------------------------------------------------------------===//
static void InitializeModuleAndManagers() {
// Open a new context and module.
TheContext = std::make_unique<LLVMContext>();
TheModule = std::make_unique<Module>("KaleidoscopeJIT", *TheContext);
TheModule->setDataLayout(TheJIT->getDataLayout());
// Create a new builder for the module.
Builder = std::make_unique<IRBuilder<>>(*TheContext);
// Create new pass and analysis managers.
TheFPM = std::make_unique<FunctionPassManager>();
TheLAM = std::make_unique<LoopAnalysisManager>();
TheFAM = std::make_unique<FunctionAnalysisManager>();
TheCGAM = std::make_unique<CGSCCAnalysisManager>();
TheMAM = std::make_unique<ModuleAnalysisManager>();
ThePIC = std::make_unique<PassInstrumentationCallbacks>();
TheSI = std::make_unique<StandardInstrumentations>(*TheContext,
/*DebugLogging*/ true);
TheSI->registerCallbacks(*ThePIC, TheMAM.get());
// Add transform passes.
// Do simple "peephole" optimizations and bit-twiddling optzns.
TheFPM->addPass(InstCombinePass());
// Reassociate expressions.
TheFPM->addPass(ReassociatePass());
// Eliminate Common SubExpressions.
TheFPM->addPass(GVNPass());
// Simplify the control flow graph (deleting unreachable blocks, etc).
TheFPM->addPass(SimplifyCFGPass());
// Register analysis passes used in these transform passes.
PassBuilder PB;
PB.registerModuleAnalyses(*TheMAM);
PB.registerFunctionAnalyses(*TheFAM);
PB.crossRegisterProxies(*TheLAM, *TheFAM, *TheCGAM, *TheMAM);
}
static void HandleDefinition() {
if (auto FnAST = ParseDefinition()) {
if (auto *FnIR = FnAST->codegen()) {
fprintf(stderr, "Read function definition:");
FnIR->print(errs());
fprintf(stderr, "\n");
ExitOnErr(TheJIT->addModule(
ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
InitializeModuleAndManagers();
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
static void HandleExtern() {
if (auto ProtoAST = ParseExtern()) {
if (auto *FnIR = ProtoAST->codegen()) {
fprintf(stderr, "Read extern: ");
FnIR->print(errs());
fprintf(stderr, "\n");
FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
static void HandleTopLevelExpression() {
// Evaluate a top-level expression into an anonymous function.
if (auto FnAST = ParseTopLevelExpr()) {
if (FnAST->codegen()) {
// Create a ResourceTracker to track JIT'd memory allocated to our
// anonymous expression -- that way we can free it after executing.
auto RT = TheJIT->getMainJITDylib().createResourceTracker();
auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
InitializeModuleAndManagers();
// Search the JIT for the __anon_expr symbol.
auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr"));
// Get the symbol's address and cast it to the right type (takes no
// arguments, returns a double) so we can call it as a native function.
double (*FP)() = ExprSymbol.getAddress().toPtr<double (*)()>();
fprintf(stderr, "Evaluated to %f\n", FP());
// Delete the anonymous expression module from the JIT.
ExitOnErr(RT->remove());
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
/// top ::= definition | external | expression | ';'
static void MainLoop() {
while (true) {
fprintf(stderr, "ready> ");
switch (CurTok) {
case tok_eof:
return;
case ';': // ignore top-level semicolons.
getNextToken();
break;
case tok_def:
HandleDefinition();
break;
case tok_extern:
HandleExtern();
break;
default:
HandleTopLevelExpression();
break;
}
}
}
//===----------------------------------------------------------------------===//
// "Library" functions that can be "extern'd" from user code.
//===----------------------------------------------------------------------===//
#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif
/// putchard - putchar that takes a double and returns 0.
extern "C" DLLEXPORT double putchard(double X) {
fputc((char)X, stderr);
return 0;
}
/// printd - printf that takes a double prints it as "%f\n", returning 0.
extern "C" DLLEXPORT double printd(double X) {
fprintf(stderr, "%f\n", X);
return 0;
}
//===----------------------------------------------------------------------===//
// Main driver code.
//===----------------------------------------------------------------------===//
int main() {
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
InitializeNativeTargetAsmParser();
// Install standard binary operators.
// 1 is lowest precedence.
BinopPrecedence['<'] = 10;
BinopPrecedence['+'] = 20;
BinopPrecedence['-'] = 20;
BinopPrecedence['*'] = 40; // highest.
// Prime the first token.
fprintf(stderr, "ready> ");
getNextToken();
TheJIT = ExitOnErr(KaleidoscopeJIT::Create());
InitializeModuleAndManagers();
// Run the main "interpreter loop" now.
MainLoop();
return 0;
}