【译】如何使用 Python 创建一个虚拟机解释器?

赵斌

原文地址:Making a simple VM interpreter in Python

更新:根据大家的评论我对代码做了轻微的改动。感谢 robin-gvx、 bs4h 和 Dagur,具体代码见这里

Stack Machine 本身并没有任何的寄存器,它将所需要处理的值全部放入堆栈中而后进行处理。Stack Machine 虽然简单但是却十分强大,这也是为神马 Python,Java,PostScript,Forth 和其他语言都选择它作为自己的虚拟机的原因。

首先,我们先来谈谈堆栈。我们需要一个指令指针栈用于保存返回地址。这样当我们调用了一个子例程(比如调用一个函数)的时候我们就能够返回到我们开始调用的地方了。我们可以使用自修改代码(self-modifying code)来做这件事,恰如 Donald Knuth 发起的 MIX 所做的那样。但是如果这么做的话你不得不自己维护堆栈从而保证递归能正常工作。在这篇文章中,我并不会真正的实现子例程调用,但是要实现它其实并不难(可以考虑把实现它当成练习)。

有了堆栈之后你会省很多事儿。举个例子来说,考虑这样一个表达式 (2+3)*4。在 Stack Machine 上与这个表达式等价的代码为 2 3 + 4 *。首先,将 23 推入堆栈中,接下来的是操作符 +,此时让堆栈弹出这两个数值,再把它两加合之后的结果重新入栈。然后将 4 入堆,而后让堆栈弹出两个数值,再把他们相乘之后的结果重新入栈。多么简单啊!

让我们开始写一个简单的堆栈类吧。让这个类继承 collections.deque

from collections import deque

class Stack(deque):
    push = deque.append

    def top(self):
        return self[-1]

现在我们有了 pushpoptop 这三个方法。top 方法用于查看栈顶元素。

接下来,我们实现虚拟机这个类。在虚拟机中我们需要两个堆栈以及一些内存空间来存储程序本身(译者注:这里的程序请结合下文理解)。得益于 Pyhton 的动态类型我们可以往 list 中放入任何类型。唯一的问题是我们无法区分出哪些是字符串哪些是内置函数。正确的做法是只将真正的 Python 函数放入 list 中。我可能会在将来实现这一点。

我们同时还需要一个指令指针指向程序中下一个要执行的代码。

class Machine:
    def __init__(self, code):
        self.data_stack = Stack()
        self.return_addr_stack = Stack()
        self.instruction_pointer = 0
        self.code = code

这时候我们增加一些方便使用的函数省得以后多敲键盘。

def pop(self):
    return self.data_stack.pop()

def push(self, value):
    self.data_stack.push(value)

def top(self):
    return self.data_stack.top()

然后我们增加一个 dispatch 函数来完成每一个操作码做的事儿(我们并不是真正的使用操作码,只是动态展开它,你懂的)。首先,增加一个解释器所必须的循环:

def run(self):
    while self.instruction_pointer < len(self.code):
        opcode = self.code[self.instruction_pointer]
        self.instruction_pointer += 1
        self.dispatch(opcode)

诚如您所见的,这货只好好的做一件事儿,即获取下一条指令,让指令指针执自增,然后根据操作码分别处理。dispatch 函数的代码稍微长了一点。

def dispatch(self, op):
    dispatch_map = {
        "%":        self.mod,
        "*":        self.mul,
        "+":        self.plus,
        "-":        self.minus,
        "/":        self.div,
        "==":       self.eq,
        "cast_int": self.cast_int,
        "cast_str": self.cast_str,
        "drop":     self.drop,
        "dup":      self.dup,
        "if":       self.if_stmt,
        "jmp":      self.jmp,
        "over":     self.over,
        "print":    self.print_,
        "println":  self.println,
        "read":     self.read,
        "stack":    self.dump_stack,
        "swap":     self.swap,
    }

    if op in dispatch_map:
        dispatch_map[op]()
    elif isinstance(op, int):
        # push numbers on the data stack
        self.push(op)
    elif isinstance(op, str) and op[0]==op[-1]=='"':
        # push quoted strings on the data stack
        self.push(op[1:-1])
    else:
        raise RuntimeError("Unknown opcode: '%s'" % op)

基本上,这段代码只是根据操作码查找是都有对应的处理函数,例如 * 对应 self.muldrop 对应 self.dropdup 对应 self.dup。顺便说一句,你在这里看到的这段代码其实本质上就是简单版的 Forth。而且,Forth 语言还是值得您看看的。

总之捏,它一但发现操作码是 * 的话就直接调用 self.mul 并执行它。就像这样:

def mul(self):
    self.push(self.pop() * self.pop())

其他的函数也是类似这样的。如果我们在 dispatch_map 中查找不到相应操作函数,我们首先检查他是不是数字类型,如果是的话直接入栈;如果是被引号括起来的字符串的话也是同样处理--直接入栈。

截止现在,恭喜你,一个虚拟机就完成了。

让我们定义更多的操作,然后使用我们刚完成的虚拟机和p-code 语言 来写程序。

# Allow to use "print" as a name for our own method:
from __future__ import print_function

# ...

def plus(self):
    self.push(self.pop() + self.pop())

def minus(self):
    last = self.pop()
    self.push(self.pop() - last)

def mul(self):
    self.push(self.pop() * self.pop())

def div(self):
    last = self.pop()
    self.push(self.pop() / last)

def print(self):
    sys.stdout.write(str(self.pop()))
    sys.stdout.flush()

def println(self):
    sys.stdout.write("%s\n" % self.pop())
    sys.stdout.flush()

让我们用我们的虚拟机写个与 print((2+3)*4) 等同效果的例子。

Machine([2, 3, "+", 4, "*", "println"]).run()

你可以试着运行它。

现在引入一个新的操作 jump, 即 go-to 操作

def jmp(self):
    addr = self.pop()
    if isinstance(addr, int) and 0 <= addr < len(self.code):
        self.instruction_pointer = addr
    else:
        raise RuntimeError("JMP address must be a valid integer.")

它只改变指令指针的值。我们再看看分支跳转是怎么做的。

def if_stmt(self):
    false_clause = self.pop()
    true_clause = self.pop()
    test = self.pop()
    self.push(true_clause if test else false_clause)

这同样也是很直白的。如果你想要添加一个条件跳转,你只要简单的执行 test-value true-value false-value IF JMP 就可以了.(分支处理是很常见的操作,许多虚拟机都提供类似 JNE 这样的操作。JNEjump if not equal 的缩写)。

下面的程序要求使用者输入两个数字,然后打印出他们的和和乘积。

Machine([
    '"Enter a number: "', "print", "read", "cast_int",
    '"Enter another number: "', "print", "read", "cast_int",
    "over", "over",
    '"Their sum is: "', "print", "+", "println",
    '"Their product is: "', "print", "*", "println"
]).run()

overreadcast_int 这三个操作是长这样滴:

def cast_int(self):
    self.push(int(self.pop()))

def over(self):
    b = self.pop()
    a = self.pop()
    self.push(a)
    self.push(b)
    self.push(a)

def read(self):
    self.push(raw_input())

以下这一段程序要求使用者输入一个数字,然后打印出这个数字是奇数还是偶数。

Machine([
    '"Enter a number: "', "print", "read", "cast_int",
    '"The number "', "print", "dup", "print", '" is "', "print",
    2, "%", 0, "==", '"even."', '"odd."', "if", "println",
    0, "jmp" # loop forever!
]).run()

这里有个小练习给你去实现:增加 callreturn 这两个操作码。call 操作码将会做如下事情 :将当前地址推入返回堆栈中,然后调用 self.jmp()return 操作码将会做如下事情:返回堆栈弹栈,将弹栈出来元素的值赋予指令指针(这个值可以让你跳转回去或者从 call 调用中返回)。当你完成这两个命令,那么你的虚拟机就可以调用子例程了。

一个简单的解析器

创造一个模仿上述程序的小型语言。我们将把它编译成我们的机器码。

import tokenize
from StringIO import StringIO

# ...

def parse(text):
    tokens = tokenize.generate_tokens(StringIO(text).readline)
    for toknum, tokval, _, _, _ in tokens:
        if toknum == tokenize.NUMBER:
            yield int(tokval)
        elif toknum in [tokenize.OP, tokenize.STRING, tokenize.NAME]:
            yield tokval
        elif toknum == tokenize.ENDMARKER:
            break
        else:
            raise RuntimeError("Unknown token %s: '%s'" %
                    (tokenize.tok_name[toknum], tokval))

一个简单的优化:常量折叠

常量折叠(Constant folding)是窥孔优化(peephole optimization)的一个例子,也即是说再在编译期间可以针对某些明显的代码片段做些预计算的工作。比如,对于涉及到常量的数学表达式例如 2 3 + 就可以很轻松的实现这种优化。

def constant_fold(code):
    """Constant-folds simple mathematical expressions like 2 3 + to 5."""
    while True:
        # Find two consecutive numbers and an arithmetic operator
        for i, (a, b, op) in enumerate(zip(code, code[1:], code[2:])):
            if isinstance(a, int) and isinstance(b, int) \
                    and op in {"+", "-", "*", "/"}:
                m = Machine((a, b, op))
                m.run()
                code[i:i+3] = [m.top()]
                print("Constant-folded %s%s%s to %s" % (a,op,b,m.top()))
                break
        else:
            break
    return code

采用常量折叠遇到唯一问题就是我们不得不更新跳转地址,但在很多情况这是很难办到的(例如:test cast_int jmp)。针对这个问题有很多解决方法,其中一个简单的方法就是只允许跳转到程序中的命名标签上,然后在优化之后解析出他们真正的地址。

如果你实现了 Forth words,也即函数,你可以做更多的优化,比如删除可能永远不会被用到的程序代码(dead code elimination

REPL

我们可以创造一个简单的 PERL,就像这样

def repl():
    print('Hit CTRL+D or type "exit" to quit.')

    while True:
        try:
            source = raw_input("> ")
            code = list(parse(source))
            code = constant_fold(code)
            Machine(code).run()
        except (RuntimeError, IndexError) as e:
            print("IndexError: %s" % e)
        except KeyboardInterrupt:
            print("\nKeyboardInterrupt")

用一些简单的程序来测试我们的 REPL

> 2 3 + 4 * println
Constant-folded 2+3 to 5
Constant-folded 5*4 to 20
20
> 12 dup * println
144
> "Hello, world!" dup println println
Hello, world!
Hello, world!

你可以看到,常量折叠看起来运转正常。在第一个例子中,它把整个程序优化成这样 20 println

下一步

当你添加完 callreturn 之后,你便可以让使用者定义自己的函数了。在Forth 中函数被称为 words,他们以冒号开头紧接着是名字然后以分号结束。例如,一个整数平方的 word 是长这样滴

: square dup * ;

实际上,你可以试试把这一段放在程序中,比如 Gforth

$ gforth
Gforth 0.7.3, Copyright (C) 1995-2008 Free Software Foundation, Inc.
Gforth comes with ABSOLUTELY NO WARRANTY; for details type `license'
Type `bye' to exit
: square dup * ;  ok
12 square . 144  ok

你可以在解析器中通过发现 : 来支持这一点。一旦你发现一个冒号,你必须记录下它的名字及其地址(比如:在程序中的位置)然后把他们插入到符号表(symbol table)中。简单起见,你甚至可以把整个函数的代码(包括分号)放在字典中,譬如:

symbol_table = {
  "square": ["dup", "*"]
  # ...
}

当你完成了解析的工作,你可以连接你的程序:遍历整个主程序并且在符号表中寻找自定义函数的地方。一旦你找到一个并且它没有在主程序的后面出现,那么你可以把它附加到主程序的后面。然后用 <address> call 替换掉 square,这里的 <address> 是函数插入的地址。

为了保证程序能正常执行,你应该考虑剔除 jmp 操作。否则的话,你不得不解析它们。它确实能执行,但是你得按照用户编写程序的顺序保存它们。举例来说,你想在子例程之间移动,你要格外小心。你可能需要添加 exit 函数用于停止程序(可能需要告诉操作系统返回值),这样主程序就不会继续执行以至于跑到子例程中。

实际上,一个好的程序空间布局很有可能把主程序当成一个名为 main 的子例程。或者由你决定搞成什么样子。

如您所见,这一切都是很有趣的,而且通过这一过程你也学会了很多关于代码生成、链接、程序空间布局相关的知识。

更多能做的事儿

你可以使用 Python 字节码生成库来尝试将虚拟机代码为原生的 Python 字节码。或者用 Java 实现运行在 JVM 上面,这样你就可以自由使用 JITing

同样的,你也可以尝试下register machine。你可以尝试用栈帧(stack frames)实现调用栈(call stack),并基于此建立调用会话。

最后,如果你不喜欢类似 Forth 这样的语言,你可以创造运行于这个虚拟机之上的自定义语言。譬如,你可以把类似 (2+3)*4 这样的中缀表达式转化成 2 3 + 4 * 然后生成代码。你也可以允许 C 风格的代码块 { ... } 这样的话,语句 if ( test ) { ... } else { ... } 将会被翻译成

<true/false test>
<address of true block>
<address of false block>
if
jmp

<true block>
<address of end of entire if-statement> jmp

<false block>
<address of end of entire if-statement> jmp

例子,

Address  Code
-------  ----
 0       2 3 >
 3       7        # Address of true-block
 4       11       # Address of false-block
 5       if
 6       jmp      # Conditional jump based on test

# True-block
 7       "Two is greater than three."
 8       println
 9       15       # Continue main program
10       jmp

# False-block ("else { ... }")
11       "Two is less than three."
12       println
13       15       # Continue main program
14       jmp

# If-statement finished, main program continues here
15       ...

对了,你还需要添加比较操作符 != < <= > >=

我已经在我的 C++ stack machine 实现了这些东东,你可以参考下。

我已经把这里呈现出来的代码搞成了个项目 Crianza,它使用了更多的优化和实验性质的模型来吧程序编译成 Python 字节码。

祝好运!

完整的代码

下面是全部的代码,兼容 Python 2 和 Python 3

你可以通过这里 得到它。

#!/usr/bin/env python
# coding: utf-8

"""
A simple VM interpreter.

Code from the post at http://csl.name/post/vm/
This version should work on both Python 2 and 3.
"""

from __future__ import print_function
from collections import deque
from io import StringIO
import sys
import tokenize


def get_input(*args, **kw):
    """Read a string from standard input."""
    if sys.version[0] == "2":
        return raw_input(*args, **kw)
    else:
        return input(*args, **kw)


class Stack(deque):
    push = deque.append

    def top(self):
        return self[-1]


class Machine:
    def __init__(self, code):
        self.data_stack = Stack()
        self.return_stack = Stack()
        self.instruction_pointer = 0
        self.code = code

    def pop(self):
        return self.data_stack.pop()

    def push(self, value):
        self.data_stack.push(value)

    def top(self):
        return self.data_stack.top()

    def run(self):
        while self.instruction_pointer < len(self.code):
            opcode = self.code[self.instruction_pointer]
            self.instruction_pointer += 1
            self.dispatch(opcode)

    def dispatch(self, op):
        dispatch_map = {
            "%":        self.mod,
            "*":        self.mul,
            "+":        self.plus,
            "-":        self.minus,
            "/":        self.div,
            "==":       self.eq,
            "cast_int": self.cast_int,
            "cast_str": self.cast_str,
            "drop":     self.drop,
            "dup":      self.dup,
            "exit":     self.exit,
            "if":       self.if_stmt,
            "jmp":      self.jmp,
            "over":     self.over,
            "print":    self.print,
            "println":  self.println,
            "read":     self.read,
            "stack":    self.dump_stack,
            "swap":     self.swap,
        }

        if op in dispatch_map:
            dispatch_map[op]()
        elif isinstance(op, int):
            self.push(op) # push numbers on stack
        elif isinstance(op, str) and op[0]==op[-1]=='"':
            self.push(op[1:-1]) # push quoted strings on stack
        else:
            raise RuntimeError("Unknown opcode: '%s'" % op)

    # OPERATIONS FOLLOW:

    def plus(self):
        self.push(self.pop() + self.pop())

    def exit(self):
        sys.exit(0)

    def minus(self):
        last = self.pop()
        self.push(self.pop() - last)

    def mul(self):
        self.push(self.pop() * self.pop())

    def div(self):
        last = self.pop()
        self.push(self.pop() / last)

    def mod(self):
        last = self.pop()
        self.push(self.pop() % last)

    def dup(self):
        self.push(self.top())

    def over(self):
        b = self.pop()
        a = self.pop()
        self.push(a)
        self.push(b)
        self.push(a)

    def drop(self):
        self.pop()

    def swap(self):
        b = self.pop()
        a = self.pop()
        self.push(b)
        self.push(a)

    def print(self):
        sys.stdout.write(str(self.pop()))
        sys.stdout.flush()

    def println(self):
        sys.stdout.write("%s\n" % self.pop())
        sys.stdout.flush()

    def read(self):
        self.push(get_input())

    def cast_int(self):
        self.push(int(self.pop()))

    def cast_str(self):
        self.push(str(self.pop()))

    def eq(self):
        self.push(self.pop() == self.pop())

    def if_stmt(self):
        false_clause = self.pop()
        true_clause = self.pop()
        test = self.pop()
        self.push(true_clause if test else false_clause)

    def jmp(self):
        addr = self.pop()
        if isinstance(addr, int) and 0 <= addr < len(self.code):
            self.instruction_pointer = addr
        else:
            raise RuntimeError("JMP address must be a valid integer.")

    def dump_stack(self):
        print("Data stack (top first):")

        for v in reversed(self.data_stack):
            print(" - type %s, value '%s'" % (type(v), v))


def parse(text):
    # Note that the tokenizer module is intended for parsing Python source
    # code, so if you're going to expand on the parser, you may have to use
    # another tokenizer.

    if sys.version[0] == "2":
        stream = StringIO(unicode(text))
    else:
        stream = StringIO(text)

    tokens = tokenize.generate_tokens(stream.readline)

    for toknum, tokval, _, _, _ in tokens:
        if toknum == tokenize.NUMBER:
            yield int(tokval)
        elif toknum in [tokenize.OP, tokenize.STRING, tokenize.NAME]:
            yield tokval
        elif toknum == tokenize.ENDMARKER:
            break
        else:
            raise RuntimeError("Unknown token %s: '%s'" %
                    (tokenize.tok_name[toknum], tokval))

def constant_fold(code):
    """Constant-folds simple mathematical expressions like 2 3 + to 5."""
    while True:
        # Find two consecutive numbers and an arithmetic operator
        for i, (a, b, op) in enumerate(zip(code, code[1:], code[2:])):
            if isinstance(a, int) and isinstance(b, int) \
                    and op in {"+", "-", "*", "/"}:
                m = Machine((a, b, op))
                m.run()
                code[i:i+3] = [m.top()]
                print("Constant-folded %s%s%s to %s" % (a,op,b,m.top()))
                break
        else:
            break
    return code

def repl():
    print('Hit CTRL+D or type "exit" to quit.')

    while True:
        try:
            source = get_input("> ")
            code = list(parse(source))
            code = constant_fold(code)
            Machine(code).run()
        except (RuntimeError, IndexError) as e:
            print("IndexError: %s" % e)
        except KeyboardInterrupt:
            print("\nKeyboardInterrupt")

def test(code = [2, 3, "+", 5, "*", "println"]):
    print("Code before optimization: %s" % str(code))
    optimized = constant_fold(code)
    print("Code after optimization: %s" % str(optimized))

    print("Stack after running original program:")
    a = Machine(code)
    a.run()
    a.dump_stack()

    print("Stack after running optimized program:")
    b = Machine(optimized)
    b.run()
    b.dump_stack()

    result = a.data_stack == b.data_stack
    print("Result: %s" % ("OK" if result else "FAIL"))
    return result

def examples():
    print("** Program 1: Runs the code for `print((2+3)*4)`")
    Machine([2, 3, "+", 4, "*", "println"]).run()

    print("\n** Program 2: Ask for numbers, computes sum and product.")
    Machine([
        '"Enter a number: "', "print", "read", "cast_int",
        '"Enter another number: "', "print", "read", "cast_int",
        "over", "over",
        '"Their sum is: "', "print", "+", "println",
        '"Their product is: "', "print", "*", "println"
    ]).run()

    print("\n** Program 3: Shows branching and looping (use CTRL+D to exit).")
    Machine([
        '"Enter a number: "', "print", "read", "cast_int",
        '"The number "', "print", "dup", "print", '" is "', "print",
        2, "%", 0, "==", '"even."', '"odd."', "if", "println",
        0, "jmp" # loop forever!
    ]).run()


if __name__ == "__main__":
    try:
        if len(sys.argv) > 1:
            cmd = sys.argv[1]
            if cmd == "repl":
                repl()
            elif cmd == "test":
                test()
                examples()
            else:
                print("Commands: repl, test")
        else:
            repl()
    except EOFError:
        print("")