4. 万花筒:添加 JIT 和优化器支持¶
4.1. 第 4 章 简介¶
欢迎来到“使用 LLVM 实现语言”教程的第 4 章。第 1-3 章介绍了简单语言的实现,并添加了对生成 LLVM IR 的支持。本章介绍两种新技术:为您的语言添加优化器支持,以及添加 JIT 编译器支持。这些添加将演示如何为 Kaleidoscope 语言获得良好、高效的代码。
4.2. 简单的常量折叠¶
我们在第 3 章中的演示优雅且易于扩展。不幸的是,它不会生成出色的代码。然而,IRBuilder 在编译简单代码时确实为我们提供了明显的优化
ready> def test(x) 1+2+x;
Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double 3.000000e+00, %x
ret double %addtmp
}
此代码不是解析输入构建的 AST 的字面转录。那将是
ready> def test(x) 1+2+x;
Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double 2.000000e+00, 1.000000e+00
%addtmp1 = fadd double %addtmp, %x
ret double %addtmp1
}
如上所示,常量折叠尤其是一种非常常见且非常重要的优化:如此重要,以至于许多语言实现者在其 AST 表示中实现了常量折叠支持。
使用 LLVM,您无需在 AST 中获得此支持。由于所有构建 LLVM IR 的调用都通过 LLVM IR 构建器,因此构建器本身会检查在您调用它时是否存在常量折叠机会。如果存在,它只会进行常量折叠并返回常量,而不是创建指令。
嗯,这很简单 :)。在实践中,我们建议在生成这样的代码时始终使用 IRBuilder
。它的使用没有“语法开销”(您不必到处用常量检查来丑化您的编译器),并且在某些情况下(特别是对于带有宏预处理器或使用大量常量的语言),它可以显着减少生成的 LLVM IR 的数量。
另一方面,IRBuilder
受限于它在代码构建时进行所有内联分析的事实。如果您采用稍微复杂的示例
ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double 3.000000e+00, %x
%addtmp1 = fadd double %x, 3.000000e+00
%multmp = fmul double %addtmp, %addtmp1
ret double %multmp
}
在这种情况下,乘法的 LHS 和 RHS 是相同的值。我们真的希望看到它生成“tmp = x+3; result = tmp*tmp;
”而不是计算两次“x+3
”。
不幸的是,任何数量的本地分析都无法检测和纠正这一点。这需要两个转换:表达式的重新关联(使加法在词法上相同)和公共子表达式消除 (CSE) 以删除冗余的加法指令。幸运的是,LLVM 提供了广泛的优化,您可以以“pass”的形式使用它们。
4.3. LLVM 优化 Pass¶
LLVM 提供了许多优化 pass,它们执行许多不同种类的事情,并且具有不同的权衡。与其他系统不同,LLVM 不坚持认为一组优化适用于所有语言和所有情况的错误观念。LLVM 允许编译器实现者完全决定使用哪些优化,以什么顺序以及在什么情况下使用。
作为一个具体的例子,LLVM 同时支持“整个模块”pass,它尽可能地跨越大量的代码(通常是一个完整的文件,但如果在链接时运行,这可能是整个程序的重要部分)。它还支持并包含“按函数”pass,这些 pass 仅一次对单个函数进行操作,而不查看其他函数。有关 pass 以及如何运行它们的更多信息,请参阅如何编写 Pass文档和LLVM Pass 列表。
对于 Kaleidoscope,我们目前正在动态生成函数,一次一个,就像用户键入它们一样。我们并非在追求此设置中的最终优化体验,但我们也希望在可能的情况下捕获简单而快速的东西。因此,我们将选择在用户键入函数时运行一些按函数优化。如果我们想制作“静态 Kaleidoscope 编译器”,我们将完全使用我们现在的代码,只不过我们会推迟运行优化器,直到整个文件都已解析。
除了函数 pass 和模块 pass 之间的区别外,pass 还可以分为转换 pass 和分析 pass。转换 pass 改变 IR,分析 pass 计算其他 pass 可以使用的信息。为了添加转换 pass,它所依赖的所有分析 pass 都必须提前注册。
为了使按函数优化开始运行,我们需要设置一个 FunctionPassManager 来保存和组织我们想要运行的 LLVM 优化。一旦我们有了它,我们就可以添加一组优化来运行。对于我们要优化的每个模块,我们需要一个新的 FunctionPassManager,因此我们将添加到上一章中创建的函数中 (InitializeModule()
)
void InitializeModuleAndManagers(void) {
// Open a new context and module.
TheContext = std::make_unique<LLVMContext>();
TheModule = std::make_unique<Module>("KaleidoscopeJIT", *TheContext);
TheModule->setDataLayout(TheJIT->getDataLayout());
// Create a new builder for the module.
Builder = std::make_unique<IRBuilder<>>(*TheContext);
// Create new pass and analysis managers.
TheFPM = std::make_unique<FunctionPassManager>();
TheLAM = std::make_unique<LoopAnalysisManager>();
TheFAM = std::make_unique<FunctionAnalysisManager>();
TheCGAM = std::make_unique<CGSCCAnalysisManager>();
TheMAM = std::make_unique<ModuleAnalysisManager>();
ThePIC = std::make_unique<PassInstrumentationCallbacks>();
TheSI = std::make_unique<StandardInstrumentations>(*TheContext,
/*DebugLogging*/ true);
TheSI->registerCallbacks(*ThePIC, TheMAM.get());
...
在初始化全局模块 TheModule
和 FunctionPassManager 之后,我们需要初始化框架的其他部分。四个 AnalysisManager 允许我们添加跨越 IR 层级结构的四个级别的分析 pass。PassInstrumentationCallbacks 和 StandardInstrumentations 是 pass instrumentation 框架所必需的,该框架允许开发人员自定义 pass 之间发生的事情。
一旦设置好这些管理器,我们就使用一系列“addPass”调用来添加一堆 LLVM 转换 pass
// Add transform passes.
// Do simple "peephole" optimizations and bit-twiddling optzns.
TheFPM->addPass(InstCombinePass());
// Reassociate expressions.
TheFPM->addPass(ReassociatePass());
// Eliminate Common SubExpressions.
TheFPM->addPass(GVNPass());
// Simplify the control flow graph (deleting unreachable blocks, etc).
TheFPM->addPass(SimplifyCFGPass());
在本例中,我们选择添加四个优化 pass。我们在此处选择的 pass 是一组非常标准的“清理”优化,这些优化对于各种代码都很有用。我不会深入研究它们的作用,但请相信我,它们是一个很好的起点 :)。
接下来,我们注册转换 pass 使用的分析 pass。
// Register analysis passes used in these transform passes.
PassBuilder PB;
PB.registerModuleAnalyses(*TheMAM);
PB.registerFunctionAnalyses(*TheFAM);
PB.crossRegisterProxies(*TheLAM, *TheFAM, *TheCGAM, *TheMAM);
}
设置好 PassManager 后,我们需要使用它。我们通过在新建的函数构建后(在 FunctionAST::codegen()
中),但在将其返回给客户端之前运行它来实现这一点
if (Value *RetVal = Body->codegen()) {
// Finish off the function.
Builder.CreateRet(RetVal);
// Validate the generated code, checking for consistency.
verifyFunction(*TheFunction);
// Optimize the function.
TheFPM->run(*TheFunction, *TheFAM);
return TheFunction;
}
如您所见,这非常简单。 FunctionPassManager
就地优化和更新 LLVM Function*,从而(希望)改进其主体。有了这个,我们可以再次尝试上面的测试
ready> def test(x) (1+2+x)*(x+(1+2));
ready> Read function definition:
define double @test(double %x) {
entry:
%addtmp = fadd double %x, 3.000000e+00
%multmp = fmul double %addtmp, %addtmp
ret double %multmp
}
正如预期的那样,我们现在得到了经过良好优化的代码,从而在每次执行此函数时节省了一个浮点加法指令。
LLVM 提供了各种各样的优化,可以在某些情况下使用。一些关于各种 pass 的文档可用,但不是很完整。另一个好的想法来源可以来自查看 Clang
运行的 pass 以开始。 “opt
”工具允许您从命令行试验 pass,以便您查看它们是否执行任何操作。
现在我们已经从前端获得了合理的代码,让我们来谈谈执行它!
4.4. 添加 JIT 编译器¶
LLVM IR 中可用的代码可以应用各种各样的工具。例如,您可以对其运行优化(如我们上面所做的那样),您可以以文本或二进制形式转储它,您可以将代码编译为某些目标的汇编文件 (.s),或者您可以 JIT 编译它。关于 LLVM IR 表示形式的好处是,它是编译器许多不同部分之间的“通用货币”。
在本节中,我们将为我们的解释器添加 JIT 编译器支持。我们希望 Kaleidoscope 的基本想法是让用户像现在一样输入函数体,但立即评估他们键入的顶层表达式。例如,如果他们键入“1 + 2;”,我们应该评估并打印出 3。如果他们定义了一个函数,他们应该能够从命令行调用它。
为了做到这一点,我们首先准备环境,为当前本机目标创建代码,并声明和初始化 JIT。这是通过调用一些 InitializeNativeTarget\*
函数并添加一个全局变量 TheJIT
,并在 main
中初始化它来完成的
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
...
int main() {
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
InitializeNativeTargetAsmParser();
// Install standard binary operators.
// 1 is lowest precedence.
BinopPrecedence['<'] = 10;
BinopPrecedence['+'] = 20;
BinopPrecedence['-'] = 20;
BinopPrecedence['*'] = 40; // highest.
// Prime the first token.
fprintf(stderr, "ready> ");
getNextToken();
TheJIT = std::make_unique<KaleidoscopeJIT>();
// Run the main "interpreter loop" now.
MainLoop();
return 0;
}
我们还需要为 JIT 设置数据布局
void InitializeModuleAndPassManager(void) {
// Open a new context and module.
TheContext = std::make_unique<LLVMContext>();
TheModule = std::make_unique<Module>("my cool jit", TheContext);
TheModule->setDataLayout(TheJIT->getDataLayout());
// Create a new builder for the module.
Builder = std::make_unique<IRBuilder<>>(*TheContext);
// Create a new pass manager attached to it.
TheFPM = std::make_unique<legacy::FunctionPassManager>(TheModule.get());
...
KaleidoscopeJIT 类是一个专门为这些教程构建的简单 JIT,可在 LLVM 源代码中的 llvm-src/examples/Kaleidoscope/include/KaleidoscopeJIT.h 中找到。在后面的章节中,我们将了解它的工作原理并使用新功能对其进行扩展,但现在我们将把它视为给定的。它的 API 非常简单:addModule
将 LLVM IR 模块添加到 JIT,使其函数可用于执行(其内存由 ResourceTracker
管理);并且 lookup
允许我们查找指向已编译代码的指针。
我们可以采用这个简单的 API 并更改我们的代码,该代码解析顶层表达式以使其看起来像这样
static ExitOnError ExitOnErr;
...
static void HandleTopLevelExpression() {
// Evaluate a top-level expression into an anonymous function.
if (auto FnAST = ParseTopLevelExpr()) {
if (FnAST->codegen()) {
// Create a ResourceTracker to track JIT'd memory allocated to our
// anonymous expression -- that way we can free it after executing.
auto RT = TheJIT->getMainJITDylib().createResourceTracker();
auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
InitializeModuleAndPassManager();
// Search the JIT for the __anon_expr symbol.
auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr"));
assert(ExprSymbol && "Function not found");
// Get the symbol's address and cast it to the right type (takes no
// arguments, returns a double) so we can call it as a native function.
double (*FP)() = ExprSymbol.getAddress().toPtr<double (*)()>();
fprintf(stderr, "Evaluated to %f\n", FP());
// Delete the anonymous expression module from the JIT.
ExitOnErr(RT->remove());
}
如果解析和代码生成成功,则下一步是将包含顶层表达式的模块添加到 JIT。我们通过调用 addModule 来完成此操作,这将触发模块中所有函数的代码生成,并接受一个 ResourceTracker
,该 ResourceTracker 可用于稍后从 JIT 中删除模块。一旦模块被添加到 JIT,它就不能再被修改,因此我们还通过调用 InitializeModuleAndPassManager()
打开一个新模块来保存后续代码。
一旦我们将模块添加到 JIT,我们需要获取指向最终生成的代码的指针。我们通过调用 JIT 的 lookup
方法,并传递顶层表达式函数的名称来做到这一点:__anon_expr
。由于我们刚刚添加了这个函数,我们断言 lookup
返回了一个结果。
接下来,我们通过在符号上调用 getAddress()
来获取 __anon_expr
函数的内存地址。回想一下,我们将顶层表达式编译成一个独立的 LLVM 函数,该函数不带参数并返回计算出的 double。由于 LLVM JIT 编译器与本机平台 ABI 匹配,这意味着您可以将结果指针强制转换为该类型的函数指针并直接调用它。这意味着,JIT 编译代码和静态链接到您的应用程序中的本机机器代码之间没有区别。
最后,由于我们不支持顶层表达式的重新评估,因此我们在完成后从 JIT 中删除模块以释放关联的内存。但是,回想一下,我们前几行创建的模块(通过 InitializeModuleAndPassManager
)仍然处于打开状态,等待添加新代码。
仅需这两个更改,让我们看看 Kaleidoscope 现在是如何工作的!
ready> 4+5;
Read top-level expression:
define double @0() {
entry:
ret double 9.000000e+00
}
Evaluated to 9.000000
嗯,这看起来基本上可以工作了。函数的转储显示了我们为键入的每个顶层表达式合成的“始终返回 double 的无参数函数”。这演示了非常基本的功能,但是我们能做更多的事情吗?
ready> def testfunc(x y) x + y*2;
Read function definition:
define double @testfunc(double %x, double %y) {
entry:
%multmp = fmul double %y, 2.000000e+00
%addtmp = fadd double %multmp, %x
ret double %addtmp
}
ready> testfunc(4, 10);
Read top-level expression:
define double @1() {
entry:
%calltmp = call double @testfunc(double 4.000000e+00, double 1.000000e+01)
ret double %calltmp
}
Evaluated to 24.000000
ready> testfunc(5, 10);
ready> LLVM ERROR: Program used external function 'testfunc' which could not be resolved!
函数定义和调用也有效,但是最后一行出了问题。调用看起来有效,那么发生了什么?正如您可能从 API 中猜到的那样,Module 是 JIT 的分配单元,而 testfunc 是包含匿名表达式的同一模块的一部分。当我们从 JIT 中删除该模块以释放匿名表达式的内存时,我们同时删除了 testfunc
的定义。然后,当我们第二次尝试调用 testfunc 时,JIT 找不到它了。
解决此问题的最简单方法是将匿名表达式放在与其余函数定义不同的模块中。只要调用的每个函数都有原型,并且在调用之前添加到 JIT,JIT 就会很乐意跨模块边界解析函数调用。通过将匿名表达式放在不同的模块中,我们可以删除它而不会影响其余函数。
实际上,我们将更进一步,并将每个函数都放在自己的模块中。这样做使我们能够利用 KaleidoscopeJIT 的有用属性,这将使我们的环境更像 REPL:函数可以多次添加到 JIT(与模块不同,模块中每个函数都必须具有唯一的定义)。当您在 KaleidoscopeJIT 中查找符号时,它将始终返回最新的定义
ready> def foo(x) x + 1;
Read function definition:
define double @foo(double %x) {
entry:
%addtmp = fadd double %x, 1.000000e+00
ret double %addtmp
}
ready> foo(2);
Evaluated to 3.000000
ready> def foo(x) x + 2;
define double @foo(double %x) {
entry:
%addtmp = fadd double %x, 2.000000e+00
ret double %addtmp
}
ready> foo(2);
Evaluated to 4.000000
为了允许每个函数都存在于其自己的模块中,我们需要一种方法来将先前的函数声明重新生成到我们打开的每个新模块中
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
...
Function *getFunction(std::string Name) {
// First, see if the function has already been added to the current module.
if (auto *F = TheModule->getFunction(Name))
return F;
// If not, check whether we can codegen the declaration from some existing
// prototype.
auto FI = FunctionProtos.find(Name);
if (FI != FunctionProtos.end())
return FI->second->codegen();
// If no existing prototype exists, return null.
return nullptr;
}
...
Value *CallExprAST::codegen() {
// Look up the name in the global module table.
Function *CalleeF = getFunction(Callee);
...
Function *FunctionAST::codegen() {
// Transfer ownership of the prototype to the FunctionProtos map, but keep a
// reference to it for use below.
auto &P = *Proto;
FunctionProtos[Proto->getName()] = std::move(Proto);
Function *TheFunction = getFunction(P.getName());
if (!TheFunction)
return nullptr;
为了启用此功能,我们将首先添加一个新的全局变量 FunctionProtos
,它保存每个函数的最新原型。我们还将添加一个方便的方法 getFunction()
,以替换对 TheModule->getFunction()
的调用。我们的便捷方法在 TheModule
中搜索现有的函数声明,如果找不到,则回退到从 FunctionProtos 生成新的声明。在 CallExprAST::codegen()
中,我们只需要替换对 TheModule->getFunction()
的调用。在 FunctionAST::codegen()
中,我们需要首先更新 FunctionProtos 映射,然后调用 getFunction()
。完成此操作后,我们始终可以在当前模块中获取任何先前声明的函数的函数声明。
我们还需要更新 HandleDefinition 和 HandleExtern
static void HandleDefinition() {
if (auto FnAST = ParseDefinition()) {
if (auto *FnIR = FnAST->codegen()) {
fprintf(stderr, "Read function definition:");
FnIR->print(errs());
fprintf(stderr, "\n");
ExitOnErr(TheJIT->addModule(
ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
InitializeModuleAndPassManager();
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
static void HandleExtern() {
if (auto ProtoAST = ParseExtern()) {
if (auto *FnIR = ProtoAST->codegen()) {
fprintf(stderr, "Read extern: ");
FnIR->print(errs());
fprintf(stderr, "\n");
FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
在 HandleDefinition 中,我们添加了两行代码以将新定义的函数传输到 JIT 并打开一个新模块。在 HandleExtern 中,我们只需要添加一行代码将原型添加到 FunctionProtos。
警告
自 LLVM-9 以来,不允许在单独的模块中重复符号。这意味着您不能像下面所示那样在 Kaleidoscope 中重新定义函数。只需跳过此部分即可。
原因是较新的 OrcV2 JIT API 试图非常接近静态和动态链接器规则,包括拒绝重复符号。要求符号名称是唯一的,这使我们能够使用(唯一的)符号名称作为跟踪密钥来支持符号的并发编译。
完成这些更改后,让我们再次尝试我们的 REPL(这次我删除了匿名函数的转储,您现在应该了解了 :)
ready> def foo(x) x + 1;
ready> foo(2);
Evaluated to 3.000000
ready> def foo(x) x + 2;
ready> foo(2);
Evaluated to 4.000000
它工作了!
即使使用这段简单的代码,我们也能获得一些出乎意料的强大功能 - 看看这个
ready> extern sin(x);
Read extern:
declare double @sin(double)
ready> extern cos(x);
Read extern:
declare double @cos(double)
ready> sin(1.0);
Read top-level expression:
define double @2() {
entry:
ret double 0x3FEAED548F090CEE
}
Evaluated to 0.841471
ready> def foo(x) sin(x)*sin(x) + cos(x)*cos(x);
Read function definition:
define double @foo(double %x) {
entry:
%calltmp = call double @sin(double %x)
%multmp = fmul double %calltmp, %calltmp
%calltmp2 = call double @cos(double %x)
%multmp4 = fmul double %calltmp2, %calltmp2
%addtmp = fadd double %multmp, %multmp4
ret double %addtmp
}
ready> foo(4.0);
Read top-level expression:
define double @3() {
entry:
%calltmp = call double @foo(double 4.000000e+00)
ret double %calltmp
}
Evaluated to 1.000000
哇,JIT 是如何知道 sin 和 cos 的?答案出奇的简单:KaleidoscopeJIT 有一个直接的符号解析规则,它用于查找在任何给定模块中都不可用的符号:首先,它搜索已添加到 JIT 的所有模块,从最新的到最旧的,以查找最新的定义。如果在 JIT 内部找不到定义,它会回退到在 Kaleidoscope 进程本身上调用 “dlsym("sin")
”。由于 “sin
” 在 JIT 的地址空间内定义,因此它只是修补模块中的调用以直接调用 libm 版本的 sin
。但在某些情况下,这甚至更进一步:由于 sin 和 cos 是标准数学函数的名称,因此当使用常量(如上面的 “sin(1.0)
” 中)调用时,常量折叠器将直接评估函数调用以获得正确的结果。
将来,我们将看到如何调整此符号解析规则可用于启用各种有用的功能,从安全性(限制 JIT 代码可用的符号集)到基于符号名称的动态代码生成,甚至延迟编译。
符号解析规则的一个直接好处是,我们现在可以通过编写任意 C++ 代码来实现操作来扩展语言。例如,如果我们添加
#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif
/// putchard - putchar that takes a double and returns 0.
extern "C" DLLEXPORT double putchard(double X) {
fputc((char)X, stderr);
return 0;
}
请注意,对于 Windows,我们需要实际导出函数,因为动态符号加载器将使用 GetProcAddress
来查找符号。
现在我们可以通过使用诸如 “extern putchard(x); putchard(120);
” 之类的东西来生成简单的输出到控制台,这会在控制台上打印一个小写字母 'x'(120 是 'x' 的 ASCII 码)。类似的代码可用于在 Kaleidoscope 中实现文件 I/O、控制台输入和许多其他功能。
至此,Kaleidoscope 教程的 JIT 和优化器章节就完成了。在这一点上,我们可以编译一种非图灵完备的编程语言,以用户驱动的方式优化和 JIT 编译它。接下来,我们将研究 使用控制流结构扩展语言,并在此过程中解决一些有趣的 LLVM IR 问题。
4.5. 完整代码清单¶
这是我们正在运行的示例的完整代码清单,使用 LLVM JIT 和优化器进行了增强。要构建此示例,请使用
# Compile
clang++ -g toy.cpp `llvm-config --cxxflags --ldflags --system-libs --libs core orcjit native` -O3 -o toy
# Run
./toy
如果您在 Linux 上编译它,请确保也添加 “-rdynamic” 选项。这确保在运行时正确解析外部函数。
这是代码
#include "../include/KaleidoscopeJIT.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/StandardInstrumentations.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Scalar/GVN.h"
#include "llvm/Transforms/Scalar/Reassociate.h"
#include "llvm/Transforms/Scalar/SimplifyCFG.h"
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <vector>
using namespace llvm;
using namespace llvm::orc;
//===----------------------------------------------------------------------===//
// Lexer
//===----------------------------------------------------------------------===//
// The lexer returns tokens [0-255] if it is an unknown character, otherwise one
// of these for known things.
enum Token {
tok_eof = -1,
// commands
tok_def = -2,
tok_extern = -3,
// primary
tok_identifier = -4,
tok_number = -5
};
static std::string IdentifierStr; // Filled in if tok_identifier
static double NumVal; // Filled in if tok_number
/// gettok - Return the next token from standard input.
static int gettok() {
static int LastChar = ' ';
// Skip any whitespace.
while (isspace(LastChar))
LastChar = getchar();
if (isalpha(LastChar)) { // identifier: [a-zA-Z][a-zA-Z0-9]*
IdentifierStr = LastChar;
while (isalnum((LastChar = getchar())))
IdentifierStr += LastChar;
if (IdentifierStr == "def")
return tok_def;
if (IdentifierStr == "extern")
return tok_extern;
return tok_identifier;
}
if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
std::string NumStr;
do {
NumStr += LastChar;
LastChar = getchar();
} while (isdigit(LastChar) || LastChar == '.');
NumVal = strtod(NumStr.c_str(), nullptr);
return tok_number;
}
if (LastChar == '#') {
// Comment until end of line.
do
LastChar = getchar();
while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
if (LastChar != EOF)
return gettok();
}
// Check for end of file. Don't eat the EOF.
if (LastChar == EOF)
return tok_eof;
// Otherwise, just return the character as its ascii value.
int ThisChar = LastChar;
LastChar = getchar();
return ThisChar;
}
//===----------------------------------------------------------------------===//
// Abstract Syntax Tree (aka Parse Tree)
//===----------------------------------------------------------------------===//
namespace {
/// ExprAST - Base class for all expression nodes.
class ExprAST {
public:
virtual ~ExprAST() = default;
virtual Value *codegen() = 0;
};
/// NumberExprAST - Expression class for numeric literals like "1.0".
class NumberExprAST : public ExprAST {
double Val;
public:
NumberExprAST(double Val) : Val(Val) {}
Value *codegen() override;
};
/// VariableExprAST - Expression class for referencing a variable, like "a".
class VariableExprAST : public ExprAST {
std::string Name;
public:
VariableExprAST(const std::string &Name) : Name(Name) {}
Value *codegen() override;
};
/// BinaryExprAST - Expression class for a binary operator.
class BinaryExprAST : public ExprAST {
char Op;
std::unique_ptr<ExprAST> LHS, RHS;
public:
BinaryExprAST(char Op, std::unique_ptr<ExprAST> LHS,
std::unique_ptr<ExprAST> RHS)
: Op(Op), LHS(std::move(LHS)), RHS(std::move(RHS)) {}
Value *codegen() override;
};
/// CallExprAST - Expression class for function calls.
class CallExprAST : public ExprAST {
std::string Callee;
std::vector<std::unique_ptr<ExprAST>> Args;
public:
CallExprAST(const std::string &Callee,
std::vector<std::unique_ptr<ExprAST>> Args)
: Callee(Callee), Args(std::move(Args)) {}
Value *codegen() override;
};
/// PrototypeAST - This class represents the "prototype" for a function,
/// which captures its name, and its argument names (thus implicitly the number
/// of arguments the function takes).
class PrototypeAST {
std::string Name;
std::vector<std::string> Args;
public:
PrototypeAST(const std::string &Name, std::vector<std::string> Args)
: Name(Name), Args(std::move(Args)) {}
Function *codegen();
const std::string &getName() const { return Name; }
};
/// FunctionAST - This class represents a function definition itself.
class FunctionAST {
std::unique_ptr<PrototypeAST> Proto;
std::unique_ptr<ExprAST> Body;
public:
FunctionAST(std::unique_ptr<PrototypeAST> Proto,
std::unique_ptr<ExprAST> Body)
: Proto(std::move(Proto)), Body(std::move(Body)) {}
Function *codegen();
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//
/// CurTok/getNextToken - Provide a simple token buffer. CurTok is the current
/// token the parser is looking at. getNextToken reads another token from the
/// lexer and updates CurTok with its results.
static int CurTok;
static int getNextToken() { return CurTok = gettok(); }
/// BinopPrecedence - This holds the precedence for each binary operator that is
/// defined.
static std::map<char, int> BinopPrecedence;
/// GetTokPrecedence - Get the precedence of the pending binary operator token.
static int GetTokPrecedence() {
if (!isascii(CurTok))
return -1;
// Make sure it's a declared binop.
int TokPrec = BinopPrecedence[CurTok];
if (TokPrec <= 0)
return -1;
return TokPrec;
}
/// LogError* - These are little helper functions for error handling.
std::unique_ptr<ExprAST> LogError(const char *Str) {
fprintf(stderr, "Error: %s\n", Str);
return nullptr;
}
std::unique_ptr<PrototypeAST> LogErrorP(const char *Str) {
LogError(Str);
return nullptr;
}
static std::unique_ptr<ExprAST> ParseExpression();
/// numberexpr ::= number
static std::unique_ptr<ExprAST> ParseNumberExpr() {
auto Result = std::make_unique<NumberExprAST>(NumVal);
getNextToken(); // consume the number
return std::move(Result);
}
/// parenexpr ::= '(' expression ')'
static std::unique_ptr<ExprAST> ParseParenExpr() {
getNextToken(); // eat (.
auto V = ParseExpression();
if (!V)
return nullptr;
if (CurTok != ')')
return LogError("expected ')'");
getNextToken(); // eat ).
return V;
}
/// identifierexpr
/// ::= identifier
/// ::= identifier '(' expression* ')'
static std::unique_ptr<ExprAST> ParseIdentifierExpr() {
std::string IdName = IdentifierStr;
getNextToken(); // eat identifier.
if (CurTok != '(') // Simple variable ref.
return std::make_unique<VariableExprAST>(IdName);
// Call.
getNextToken(); // eat (
std::vector<std::unique_ptr<ExprAST>> Args;
if (CurTok != ')') {
while (true) {
if (auto Arg = ParseExpression())
Args.push_back(std::move(Arg));
else
return nullptr;
if (CurTok == ')')
break;
if (CurTok != ',')
return LogError("Expected ')' or ',' in argument list");
getNextToken();
}
}
// Eat the ')'.
getNextToken();
return std::make_unique<CallExprAST>(IdName, std::move(Args));
}
/// primary
/// ::= identifierexpr
/// ::= numberexpr
/// ::= parenexpr
static std::unique_ptr<ExprAST> ParsePrimary() {
switch (CurTok) {
default:
return LogError("unknown token when expecting an expression");
case tok_identifier:
return ParseIdentifierExpr();
case tok_number:
return ParseNumberExpr();
case '(':
return ParseParenExpr();
}
}
/// binoprhs
/// ::= ('+' primary)*
static std::unique_ptr<ExprAST> ParseBinOpRHS(int ExprPrec,
std::unique_ptr<ExprAST> LHS) {
// If this is a binop, find its precedence.
while (true) {
int TokPrec = GetTokPrecedence();
// If this is a binop that binds at least as tightly as the current binop,
// consume it, otherwise we are done.
if (TokPrec < ExprPrec)
return LHS;
// Okay, we know this is a binop.
int BinOp = CurTok;
getNextToken(); // eat binop
// Parse the primary expression after the binary operator.
auto RHS = ParsePrimary();
if (!RHS)
return nullptr;
// If BinOp binds less tightly with RHS than the operator after RHS, let
// the pending operator take RHS as its LHS.
int NextPrec = GetTokPrecedence();
if (TokPrec < NextPrec) {
RHS = ParseBinOpRHS(TokPrec + 1, std::move(RHS));
if (!RHS)
return nullptr;
}
// Merge LHS/RHS.
LHS =
std::make_unique<BinaryExprAST>(BinOp, std::move(LHS), std::move(RHS));
}
}
/// expression
/// ::= primary binoprhs
///
static std::unique_ptr<ExprAST> ParseExpression() {
auto LHS = ParsePrimary();
if (!LHS)
return nullptr;
return ParseBinOpRHS(0, std::move(LHS));
}
/// prototype
/// ::= id '(' id* ')'
static std::unique_ptr<PrototypeAST> ParsePrototype() {
if (CurTok != tok_identifier)
return LogErrorP("Expected function name in prototype");
std::string FnName = IdentifierStr;
getNextToken();
if (CurTok != '(')
return LogErrorP("Expected '(' in prototype");
std::vector<std::string> ArgNames;
while (getNextToken() == tok_identifier)
ArgNames.push_back(IdentifierStr);
if (CurTok != ')')
return LogErrorP("Expected ')' in prototype");
// success.
getNextToken(); // eat ')'.
return std::make_unique<PrototypeAST>(FnName, std::move(ArgNames));
}
/// definition ::= 'def' prototype expression
static std::unique_ptr<FunctionAST> ParseDefinition() {
getNextToken(); // eat def.
auto Proto = ParsePrototype();
if (!Proto)
return nullptr;
if (auto E = ParseExpression())
return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
return nullptr;
}
/// toplevelexpr ::= expression
static std::unique_ptr<FunctionAST> ParseTopLevelExpr() {
if (auto E = ParseExpression()) {
// Make an anonymous proto.
auto Proto = std::make_unique<PrototypeAST>("__anon_expr",
std::vector<std::string>());
return std::make_unique<FunctionAST>(std::move(Proto), std::move(E));
}
return nullptr;
}
/// external ::= 'extern' prototype
static std::unique_ptr<PrototypeAST> ParseExtern() {
getNextToken(); // eat extern.
return ParsePrototype();
}
//===----------------------------------------------------------------------===//
// Code Generation
//===----------------------------------------------------------------------===//
static std::unique_ptr<LLVMContext> TheContext;
static std::unique_ptr<Module> TheModule;
static std::unique_ptr<IRBuilder<>> Builder;
static std::map<std::string, Value *> NamedValues;
static std::unique_ptr<KaleidoscopeJIT> TheJIT;
static std::unique_ptr<FunctionPassManager> TheFPM;
static std::unique_ptr<LoopAnalysisManager> TheLAM;
static std::unique_ptr<FunctionAnalysisManager> TheFAM;
static std::unique_ptr<CGSCCAnalysisManager> TheCGAM;
static std::unique_ptr<ModuleAnalysisManager> TheMAM;
static std::unique_ptr<PassInstrumentationCallbacks> ThePIC;
static std::unique_ptr<StandardInstrumentations> TheSI;
static std::map<std::string, std::unique_ptr<PrototypeAST>> FunctionProtos;
static ExitOnError ExitOnErr;
Value *LogErrorV(const char *Str) {
LogError(Str);
return nullptr;
}
Function *getFunction(std::string Name) {
// First, see if the function has already been added to the current module.
if (auto *F = TheModule->getFunction(Name))
return F;
// If not, check whether we can codegen the declaration from some existing
// prototype.
auto FI = FunctionProtos.find(Name);
if (FI != FunctionProtos.end())
return FI->second->codegen();
// If no existing prototype exists, return null.
return nullptr;
}
Value *NumberExprAST::codegen() {
return ConstantFP::get(*TheContext, APFloat(Val));
}
Value *VariableExprAST::codegen() {
// Look this variable up in the function.
Value *V = NamedValues[Name];
if (!V)
return LogErrorV("Unknown variable name");
return V;
}
Value *BinaryExprAST::codegen() {
Value *L = LHS->codegen();
Value *R = RHS->codegen();
if (!L || !R)
return nullptr;
switch (Op) {
case '+':
return Builder->CreateFAdd(L, R, "addtmp");
case '-':
return Builder->CreateFSub(L, R, "subtmp");
case '*':
return Builder->CreateFMul(L, R, "multmp");
case '<':
L = Builder->CreateFCmpULT(L, R, "cmptmp");
// Convert bool 0/1 to double 0.0 or 1.0
return Builder->CreateUIToFP(L, Type::getDoubleTy(*TheContext), "booltmp");
default:
return LogErrorV("invalid binary operator");
}
}
Value *CallExprAST::codegen() {
// Look up the name in the global module table.
Function *CalleeF = getFunction(Callee);
if (!CalleeF)
return LogErrorV("Unknown function referenced");
// If argument mismatch error.
if (CalleeF->arg_size() != Args.size())
return LogErrorV("Incorrect # arguments passed");
std::vector<Value *> ArgsV;
for (unsigned i = 0, e = Args.size(); i != e; ++i) {
ArgsV.push_back(Args[i]->codegen());
if (!ArgsV.back())
return nullptr;
}
return Builder->CreateCall(CalleeF, ArgsV, "calltmp");
}
Function *PrototypeAST::codegen() {
// Make the function type: double(double,double) etc.
std::vector<Type *> Doubles(Args.size(), Type::getDoubleTy(*TheContext));
FunctionType *FT =
FunctionType::get(Type::getDoubleTy(*TheContext), Doubles, false);
Function *F =
Function::Create(FT, Function::ExternalLinkage, Name, TheModule.get());
// Set names for all arguments.
unsigned Idx = 0;
for (auto &Arg : F->args())
Arg.setName(Args[Idx++]);
return F;
}
Function *FunctionAST::codegen() {
// Transfer ownership of the prototype to the FunctionProtos map, but keep a
// reference to it for use below.
auto &P = *Proto;
FunctionProtos[Proto->getName()] = std::move(Proto);
Function *TheFunction = getFunction(P.getName());
if (!TheFunction)
return nullptr;
// Create a new basic block to start insertion into.
BasicBlock *BB = BasicBlock::Create(*TheContext, "entry", TheFunction);
Builder->SetInsertPoint(BB);
// Record the function arguments in the NamedValues map.
NamedValues.clear();
for (auto &Arg : TheFunction->args())
NamedValues[std::string(Arg.getName())] = &Arg;
if (Value *RetVal = Body->codegen()) {
// Finish off the function.
Builder->CreateRet(RetVal);
// Validate the generated code, checking for consistency.
verifyFunction(*TheFunction);
// Run the optimizer on the function.
TheFPM->run(*TheFunction, *TheFAM);
return TheFunction;
}
// Error reading body, remove function.
TheFunction->eraseFromParent();
return nullptr;
}
//===----------------------------------------------------------------------===//
// Top-Level parsing and JIT Driver
//===----------------------------------------------------------------------===//
static void InitializeModuleAndManagers() {
// Open a new context and module.
TheContext = std::make_unique<LLVMContext>();
TheModule = std::make_unique<Module>("KaleidoscopeJIT", *TheContext);
TheModule->setDataLayout(TheJIT->getDataLayout());
// Create a new builder for the module.
Builder = std::make_unique<IRBuilder<>>(*TheContext);
// Create new pass and analysis managers.
TheFPM = std::make_unique<FunctionPassManager>();
TheLAM = std::make_unique<LoopAnalysisManager>();
TheFAM = std::make_unique<FunctionAnalysisManager>();
TheCGAM = std::make_unique<CGSCCAnalysisManager>();
TheMAM = std::make_unique<ModuleAnalysisManager>();
ThePIC = std::make_unique<PassInstrumentationCallbacks>();
TheSI = std::make_unique<StandardInstrumentations>(*TheContext,
/*DebugLogging*/ true);
TheSI->registerCallbacks(*ThePIC, TheMAM.get());
// Add transform passes.
// Do simple "peephole" optimizations and bit-twiddling optzns.
TheFPM->addPass(InstCombinePass());
// Reassociate expressions.
TheFPM->addPass(ReassociatePass());
// Eliminate Common SubExpressions.
TheFPM->addPass(GVNPass());
// Simplify the control flow graph (deleting unreachable blocks, etc).
TheFPM->addPass(SimplifyCFGPass());
// Register analysis passes used in these transform passes.
PassBuilder PB;
PB.registerModuleAnalyses(*TheMAM);
PB.registerFunctionAnalyses(*TheFAM);
PB.crossRegisterProxies(*TheLAM, *TheFAM, *TheCGAM, *TheMAM);
}
static void HandleDefinition() {
if (auto FnAST = ParseDefinition()) {
if (auto *FnIR = FnAST->codegen()) {
fprintf(stderr, "Read function definition:");
FnIR->print(errs());
fprintf(stderr, "\n");
ExitOnErr(TheJIT->addModule(
ThreadSafeModule(std::move(TheModule), std::move(TheContext))));
InitializeModuleAndManagers();
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
static void HandleExtern() {
if (auto ProtoAST = ParseExtern()) {
if (auto *FnIR = ProtoAST->codegen()) {
fprintf(stderr, "Read extern: ");
FnIR->print(errs());
fprintf(stderr, "\n");
FunctionProtos[ProtoAST->getName()] = std::move(ProtoAST);
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
static void HandleTopLevelExpression() {
// Evaluate a top-level expression into an anonymous function.
if (auto FnAST = ParseTopLevelExpr()) {
if (FnAST->codegen()) {
// Create a ResourceTracker to track JIT'd memory allocated to our
// anonymous expression -- that way we can free it after executing.
auto RT = TheJIT->getMainJITDylib().createResourceTracker();
auto TSM = ThreadSafeModule(std::move(TheModule), std::move(TheContext));
ExitOnErr(TheJIT->addModule(std::move(TSM), RT));
InitializeModuleAndManagers();
// Search the JIT for the __anon_expr symbol.
auto ExprSymbol = ExitOnErr(TheJIT->lookup("__anon_expr"));
// Get the symbol's address and cast it to the right type (takes no
// arguments, returns a double) so we can call it as a native function.
double (*FP)() = ExprSymbol.toPtr<double (*)()>();
fprintf(stderr, "Evaluated to %f\n", FP());
// Delete the anonymous expression module from the JIT.
ExitOnErr(RT->remove());
}
} else {
// Skip token for error recovery.
getNextToken();
}
}
/// top ::= definition | external | expression | ';'
static void MainLoop() {
while (true) {
fprintf(stderr, "ready> ");
switch (CurTok) {
case tok_eof:
return;
case ';': // ignore top-level semicolons.
getNextToken();
break;
case tok_def:
HandleDefinition();
break;
case tok_extern:
HandleExtern();
break;
default:
HandleTopLevelExpression();
break;
}
}
}
//===----------------------------------------------------------------------===//
// "Library" functions that can be "extern'd" from user code.
//===----------------------------------------------------------------------===//
#ifdef _WIN32
#define DLLEXPORT __declspec(dllexport)
#else
#define DLLEXPORT
#endif
/// putchard - putchar that takes a double and returns 0.
extern "C" DLLEXPORT double putchard(double X) {
fputc((char)X, stderr);
return 0;
}
/// printd - printf that takes a double prints it as "%f\n", returning 0.
extern "C" DLLEXPORT double printd(double X) {
fprintf(stderr, "%f\n", X);
return 0;
}
//===----------------------------------------------------------------------===//
// Main driver code.
//===----------------------------------------------------------------------===//
int main() {
InitializeNativeTarget();
InitializeNativeTargetAsmPrinter();
InitializeNativeTargetAsmParser();
// Install standard binary operators.
// 1 is lowest precedence.
BinopPrecedence['<'] = 10;
BinopPrecedence['+'] = 20;
BinopPrecedence['-'] = 20;
BinopPrecedence['*'] = 40; // highest.
// Prime the first token.
fprintf(stderr, "ready> ");
getNextToken();
TheJIT = ExitOnErr(KaleidoscopeJIT::Create());
InitializeModuleAndManagers();
// Run the main "interpreter loop" now.
MainLoop();
return 0;
}