Abstract Syntax Trees 即抽象语法树. Ast 是 python 源码到字节码的一种中间产物, 借助 ast 模块可以从语法树的角度分析源码结构. 此外, 我们不仅可以修改和执行语法树, 还可以将 Source 生成的语法树 unparse 成 python 源码. 因此 ast 给 python 源码检查, 语法分析, 修改代码以及代码调试等留下了足够的发挥空间.
1. AST 简介
Python 官方提供的 CPython 解释器对 python 源码的处理过程如下:
- Parse source code into a parse tree (Parser/pgen.c)
- Transform parse tree into an Abstract Syntax Tree (Python/ast.c)
- Transform AST into a Control Flow Graph (Python/compile.c)
- Emit bytecode based on the Control Flow Graph (Python/compile.c)
即实际 python 代码的处理过程如下:
源代码解析 --> 语法树 --> 抽象语法树 (AST) --> 控制流程图 --> 字节码
上述过程在 python2.5 之后被应用. python 源码首先被解析成语法树, 随后又转换成抽象语法树. 在抽象语法树中我们可以看到源码文件中的 python 的语法结构.
大部分时间编程可能都不需要用到抽象语法树, 但是在特定的条件和需求的情况下, AST 又有其特殊的方便性.
下面是一个抽象语法的简单实例.
- Module(body=[
- Print(
- dest=None,
- values=[BinOp( left=Num(n=1),op=Add(),right=Num(n=2))],
- nl=True,
- )])
2. 创建 AST
2.1 Compile 函数
先简单了解一下 compile 函数.
compile(source, filename, mode[, flags[, dont_inherit]])
source -- 字符串或者 AST(Abstract Syntax Trees) 对象. 一般可将整个 py 文件内容 file.read() 传入.
filename -- 代码文件名称, 如果不是从文件读取代码则传递一些可辨认的值.
mode -- 指定编译代码的种类. 可以指定为 exec, eval, single.
flags -- 变量作用域, 局部命名空间, 如果被提供, 可以是任何映射对象.
flags 和 dont_inherit 是用来控制编译源码时的标志.
- func_def = \
- """
- def add(x, y):
- return x + y
- print add(3, 5)
- """
使用 Compile 编译并执行:
- >>> cm = compile(func_def, '<string>', 'exec')
- >>> exec cm
- >>> 8
上面 func_def 经过 compile 编译得到字节码, cm 即 code 对象, True == isinstance(cm, types.CodeType).
compile(source, filename, mode, ast.PyCF_ONLY_AST) <==> ast.parse(source, filename='<unknown>', mode='exec')
2.2 生成 ast
使用上面的 func_def 生成 ast.
- r_node = ast.parse(func_def)
- print astunparse.dump(r_node) # print ast.dump(r_node)
下面是 func_def 对应的 ast 结构:
- Module(body=[
- FunctionDef(
- name='add',
- args=arguments(
- args=[Name(id='x',ctx=Param()),Name(id='y',ctx=Param())],
- vararg=None,
- kwarg=None,
- defaults=[]),
- body=[Return(value=BinOp(
- left=Name(id='x',ctx=Load()),
- op=Add(),
- right=Name(id='y',ctx=Load())))],
- decorator_list=[]),
- Print(
- dest=None,
- values=[Call(
- func=Name(id='add',ctx=Load()),
- args=[Num(n=3),Num(n=5)],
- keywords=[],
- starargs=None,
- kwargs=None)],
- nl=True)
- ])
除了 ast.dump, 有很多 dump ast 的第三方库, 如 astunparse, codegen, unparse 等. 这些第三方库不仅能够以更好的方式展示出 ast 结构, 还能够将 ast 反向导出 python source 代码.
- module Python version "$Revision$"
- {
- mod = Module(stmt* body)| Expression(expr body)
- stmt = FunctionDef(identifier name, arguments args, stmt* body, expr* decorator_list)
- | ClassDef(identifier name, expr* bases, stmt* body, expr* decorator_list)
- | Return(expr? value)
- | Print(expr? dest, expr* values, bool nl)| For(expr target, expr iter, stmt* body, stmt* orelse)
- expr = BoolOp(boolop op, expr* values)
- | BinOp(expr left, operator op, expr right)| Lambda(arguments args, expr body)| Dict(expr* keys, expr* values)| Num(object n) -- a number as a PyObject.
- | Str(string s) -- need to specify raw, unicode, etc?| Name(identifier id, expr_context ctx)
- | List(expr* elts, expr_context ctx)
- -- col_offset is the byte offset in the utf8 string the parser uses
- attributes (int lineno, int col_offset)
- expr_context = Load | Store | Del | AugLoad | AugStore | Param
- boolop = And | Or
- operator = Add | Sub | Mult | Div | Mod | Pow | LShift | RShift | BitOr | BitXor | BitAnd | FloorDiv
- arguments = (expr* args, identifier? vararg, identifier? kwarg, expr* defaults)
- }
- View Code
上面是部分摘自官网的 Abstract Grammar, 实际遍历 ast Node 过程中根据 Node 的类型访问其属性.
3. 遍历 AST
python 提供了两种方式来遍历整个抽象语法树.
3.1 ast.NodeTransfer
将 func_def 中的 add 函数中的加法运算改为减法, 同时为函数实现添加调用日志.
- class CodeVisitor(ast.NodeVisitor):
- def visit_BinOp(self, node):
- if isinstance(node.op, ast.Add):
- node.op = ast.Sub()
- self.generic_visit(node)
- def visit_FunctionDef(self, node):
- print 'Function Name:%s'% node.name
- self.generic_visit(node)
- func_log_stmt = ast.Print(
- dest = None,
- values = [ast.Str(s = 'calling func: %s' % node.name, lineno = 0, col_offset = 0)],
- nl = True,
- lineno = 0,
- col_offset = 0,
- )
- node.body.insert(0, func_log_stmt)
- r_node = ast.parse(func_def)
- visitor = CodeVisitor()
- visitor.visit(r_node)
- # print astunparse.dump(r_node)
- print astunparse.unparse(r_node)
- exec compile(r_node, '<string>', 'exec')
运行结果:
- Function Name:add
- def add(x, y):
- print 'calling func: add'
- return (x - y)
- print add(3, 5)
- calling func: add
- -2
- 3.2 ast.NodeTransfer
使用 NodeVisitor 主要是通过修改语法树上节点的方式改变 AST 结构, NodeTransformer 主要是替换 ast 中的节点.
既然 func_def 中定义的 add 已经被改成一个减函数了, 那么我们就彻底一点, 把函数名和参数以及被调用的函数都在 ast 中改掉, 并且将添加的函数调用 log 写的更加复杂一些, 争取改的面目全非:-)
- class CodeTransformer(ast.NodeTransformer):
- def visit_BinOp(self, node):
- if isinstance(node.op, ast.Add):
- node.op = ast.Sub()
- self.generic_visit(node)
- return node
- def visit_FunctionDef(self, node):
- self.generic_visit(node)
- if node.name == 'add':
- node.name = 'sub'
- args_num = len(node.args.args)
- args = tuple([arg.id for arg in node.args.args])
- func_log_stmt = ''.join(["print 'calling func: %s', "% node.name,"'args:'",", %s" * args_num % args])
- node.body.insert(0, ast.parse(func_log_stmt))
- return node
- def visit_Name(self, node):
- replace = {'add': 'sub', 'x': 'a', 'y': 'b'}
- re_id = replace.get(node.id, None)
- node.id = re_id or node.id
- self.generic_visit(node)
- return node
- r_node = ast.parse(func_def)
- transformer = CodeTransformer()
- r_node = transformer.visit(r_node)
- # print astunparse.dump(r_node)
- source = astunparse.unparse(r_node)
- print source
- # exec compile(r_node, '<string>', 'exec') # 新加入的 node func_log_stmt 缺少 lineno 和 col_offset 属性
- exec compile(source, '<string>', 'exec')
- exec compile(ast.parse(source), '<string>', 'exec')
结果:
- def sub(a, b):
- print 'calling func: sub', 'args:', a, b
- return (a - b)
- print sub(3, 5)
- calling func: sub args: 3 5
- -2
- calling func: sub args: 3 5
- -2
代码中能够清楚的看到两者的区别. 这里不再赘述.
4.AST 应用
AST 模块实际编程中很少用到, 但是作为一种源代码辅助检查手段是非常有意义的; 语法检查, 调试错误, 特殊字段检测等.
上面通过为函数添加调用日志的信息是一种调试 python 源代码的一种方式, 不过实际中我们是通过 parse 整个 python 文件的方式遍历修改源码.
4.1 汉字检测
下面是中日韩字符的 unicode 编码范围
- CJK Unified Ideographs
- Range: 4E00- 9FFF
- Number of characters: 20992
Languages: chinese, japanese, korean, vietnamese
使用 unicode 范围 \u4e00 - \u9fff 来判别汉字, 注意这个范围并不包含中文字符 (e.g. u';' == u'\uff1b') .
下面是一个判断字符串中是否包含中文字符的一个类 CNCheckHelper:
- class CNCheckHelper(object):
- # 待检测文本可能的编码方式列表
- VALID_ENCODING = ('utf-8', 'gbk')
- def _get_unicode_imp(self, value, idx = 0):
- if idx < len(self.VALID_ENCODING):
- try:
- return value.decode(self.VALID_ENCODING[idx])
- except:
- return self._get_unicode_imp(value, idx + 1)
- def _get_unicode(self, from_str):
- if isinstance(from_str, unicode):
- return None
- return self._get_unicode_imp(from_str)
- def is_any_chinese(self, check_str, is_strict = True):
- unicode_str = self._get_unicode(check_str)
- if unicode_str:
- c_func = any if is_strict else all
- return c_func(u'\u4e00' <= char <= u'\u9fff' for char in unicode_str)
- return False
接口 is_any_chinese 有两种判断模式, 严格检测只要包含中文字符串就可以检查出, 非严格必须全部包含中文.
下面我们利用 ast 来遍历源文件的抽象语法树, 并检测其中字符串是否包含中文字符.
- class CodeCheck(ast.NodeVisitor):
- def __init__(self):
- self.cn_checker = CNCheckHelper()
- def visit_Str(self, node):
- self.generic_visit(node)
- # if node.s and any(u'\u4e00' <= char <= u'\u9fff' for char in node.s.decode('utf-8')):
- if self.cn_checker.is_any_chinese(node.s, True):
- print 'line no: %d, column offset: %d, CN_Str: %s' % (node.lineno, node.col_offset, node.s)
- project_dir = './your_project/script'
- for root, dirs, files in os.walk(project_dir):
- print root, dirs, files
- py_files = filter(lambda file: file.endswith('.py'), files)
- checker = CodeCheck()
- for file in py_files:
- file_path = os.path.join(root, file)
- print 'Checking: %s' % file_path
- with open(file_path, 'r') as f:
- root_node = ast.parse(f.read())
- checker.visit(root_node)
上面这个例子比较的简单, 但大概就是这个意思.
关于 CPython 解释器执行源码的过程可以参考官网描述: PEP 339 https://www.python.org/dev/peps/pep-0339/
4.2 Closure 检查
一个函数中定义的函数或者 lambda 中引用了父函数中的 local variable, 并且当做返回值返回. 特定场景下闭包是非常有用的, 但是也很容易被误用.
关于 python 闭包的概念可以参考我的另一篇文章: 理解 Python 闭包概念
这里简单介绍一下如何借助 ast 来检测 lambda 中闭包的引用. 代码如下:
- class LambdaCheck(ast.NodeVisitor):
- def __init__(self):
- self.illegal_args_list = []
- self._cur_file = None
- self._cur_lambda_args = []
- def set_cur_file(self, cur_file):
- assert os.path.isfile(cur_file), cur_file
- self._cur_file = os.path.realpath(cur_file)
- def visit_Lambda(self, node):
- """
- lambda 闭包检查原则:
- 只需检测 lambda expr body 中 args 是否引用了 lambda args list 之外的参数
- """
- self._cur_lambda_args =[a.id for a in node.args.args]
- print astunparse.unparse(node)
- # print astunparse.dump(node)
- self.get_lambda_body_args(node.body)
- self.generic_visit(node)
- def record_args(self, name_node):
- if isinstance(name_node, ast.Name) and name_node.id not in self._cur_lambda_args:
- self.illegal_args_list.append((self._cur_file, 'line no:%s' % name_node.lineno, 'var:%s' % name_node.id))
- def _is_args(self, node):
- if isinstance(node, ast.Name):
- self.record_args(node)
- return True
- if isinstance(node, ast.Call):
- map(self.record_args, node.args)
- return True
- return False
- def get_lambda_body_args(self, node):
- if self._is_args(node): return
- # for cnode in ast.walk(node):
- for cnode in ast.iter_child_nodes(node):
- if not self._is_args(cnode):
- self.get_lambda_body_args(cnode)
遍历工程文件:
- project_dir = './your project/script'
- for root, dirs, files in os.walk(project_dir):
- py_files = filter(lambda file: file.endswith('.py'), files)
- checker = LambdaCheck()
- for file in py_files:
- file_path = os.path.join(root, file)
- checker.set_cur_file(file_path)
- with open(file_path, 'r') as f:
- root_node = ast.parse(f.read())
- checker.visit(root_node)
- res = '\n'.join(['##'.join(info) for info in checker.illegal_args_list])
- print res
- View Code
由于 Lambda(arguments args, expr body) 中的 body expression 可能非常复杂, 上面的例子中仅仅处理了比较简单的 body expr. 可根据自己工程特点修改和扩展检查规则. 为了更加一般化可以单独写一个 visitor 类来遍历 lambda 节点.
Ast 的应用不仅限于上面的例子, 限于篇幅, 先介绍到这里. 期待 ast 能帮助你解决一些比较棘手的问题.
来源: https://www.cnblogs.com/yssjun/p/10069199.html