Skip to main content

ast

ast 模块用于将 Python 源码解析为抽象语法树(Abstract Syntax Tree),是理解 CPython 编译流程的核心工具。源码经过 tokenize(词法分析)→ parse(生成 AST)→ compile(生成字节码)→ 执行,AST 正是其中承上启下的关键环节。

ast

解析与查看 AST

import ast

# The parsed snippet must keep its own indentation: without it,
# ast.parse() raises IndentationError ("expected an indented block").
source = """
def add(a, b):
    return a + b
"""

tree = ast.parse(source)
print(ast.dump(tree, indent=2))
# Output shown is from Python 3.9-3.11.  Python 3.12 additionally prints
# ``type_params=[]`` for FunctionDef, and Python 3.13+ omits empty lists and
# None-valued fields unless ``show_empty=True`` is passed to ast.dump().
"""
Module(
  body=[
    FunctionDef(
      name='add',
      args=arguments(
        posonlyargs=[],
        args=[
          arg(arg='a'),
          arg(arg='b')],
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[]),
      body=[
        Return(
          value=BinOp(
            left=Name(id='a', ctx=Load()),
            op=Add(),
            right=Name(id='b', ctx=Load())))],
      decorator_list=[])],
  type_ignores=[])
"""

遍历 AST 节点

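ast.NodeVisitor 用于只读遍历,ast.NodeTransformer 用于修改 AST。两者都会按节点类名分派到对应的 visit_<类名> 方法(如 visit_Name);自定义的 visit 方法需要调用 self.generic_visit(node) 才会继续访问子节点。对 NodeTransformer 而言,visit 方法的返回值会替换原节点,返回 None 则会删除该节点。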

import ast

# Four top-level statements; every Name node in them is collected below.
source = (
    "\n"
    "x = 1\n"
    "y = x + 2\n"
    "print(y)\n"
    "z = x * y + 3\n"
)

class NameCollector(ast.NodeVisitor):
    """Record the identifier of every ``ast.Name`` node in a tree.

    This includes names in both load and store context, and therefore also
    builtins such as ``print``.  Function parameters (``ast.arg``) and
    attribute names (``ast.Attribute.attr``) are not ``Name`` nodes and are
    not collected.
    """

    def __init__(self):
        self.names = set()  # identifiers seen so far, without duplicates

    def visit_Name(self, node):
        identifier = node.id
        self.names.add(identifier)
        # Keep walking into the node's children so traversal stays complete.
        self.generic_visit(node)

tree = ast.parse(source)
collector = NameCollector()
collector.visit(tree)
# A set of strings has no stable iteration order (string hashing is
# randomized per process), so sort it to make the printed output reproducible.
print(sorted(collector.names))
# ['print', 'x', 'y', 'z']

修改 AST

import ast

class ConstantDoubler(ast.NodeTransformer):
    """Double every integer literal (an ``ast.Constant`` holding an ``int``).

    ``bool`` is a subclass of ``int`` in Python, so ``True``/``False`` are
    excluded explicitly; otherwise ``True`` would be rewritten to ``2``.
    Floats, complex numbers, strings, bytes and ``None`` are left unchanged.
    """

    def visit_Constant(self, node):
        if isinstance(node.value, int) and not isinstance(node.value, bool):
            node.value *= 2
        # Returning the (possibly modified) node keeps it in the tree;
        # returning None would remove it.
        return node

source = "result = 1 + 2 + 3"

# Parse, double every integer literal, then make sure every node carries
# line/column information before the tree is handed to compile().
tree = ast.fix_missing_locations(ConstantDoubler().visit(ast.parse(source)))

code = compile(tree, "<string>", "exec")
namespace = dict()
exec(code, namespace)
print(namespace["result"])  # 12 (= 2 + 4 + 6)

安全地求值表达式

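ast.literal_eval 只接受由 Python 字面量构成的表达式(字符串、字节串、数字、元组、列表、字典、集合、布尔值、None 等),不会进行名称查找或函数调用,因此比 eval() 安全得多。但它并非无懈可击:超大或嵌套过深的输入仍可能引发 MemoryError 或 RecursionError,处理不可信输入时应先限制其长度。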

import ast

print(ast.literal_eval("[1, 2, 3]"))  # [1, 2, 3]
print(ast.literal_eval("{'a': 1, 'b': 2}"))  # {'a': 1, 'b': 2}
print(ast.literal_eval("(True, None, 3.14)"))  # (True, None, 3.14)

# literal_eval() is documented to raise ValueError, TypeError, SyntaxError,
# MemoryError or RecursionError depending on the malformed input.  Code that
# handles untrusted text must catch all of them; catching only ValueError and
# SyntaxError would let, e.g., an unhashable dict key crash the program.
untrusted_inputs = [
    "__import__('os').system('echo hacked')",  # function call -> ValueError
    "{[1]: 2}",                                # unhashable key -> TypeError
    "1 +",                                     # incomplete expr -> SyntaxError
]
for text in untrusted_inputs:
    try:
        ast.literal_eval(text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        print(f"被拒绝: {e}")

代码分析:函数复杂度统计

import ast

# Sample module for the complexity analyzer.  The indentation inside the
# string is significant: without it ast.parse() raises IndentationError.
source = """
def simple():
    return 1

def complex_func(x):
    if x > 0:
        for i in range(x):
            if i % 2 == 0:
                yield i
    elif x < 0:
        raise ValueError("negative")
    else:
        return 0
"""

class ComplexityAnalyzer(ast.NodeVisitor):
"""统计函数中的分支与循环数量"""
def __init__(self):
self.results = {}
self._current_func = None
self._complexity = 0

def visit_FunctionDef(self, node):
old_func, old_complexity = self._current_func, self._complexity
self._current_func = node.name
self._complexity = 1
self.generic_visit(node)
self.results[node.name] = self._complexity
self._current_func, self._complexity = old_func, old_complexity

def visit_If(self, node):
self._complexity += 1
self.generic_visit(node)

def visit_For(self, node):
self._complexity += 1
self.generic_visit(node)

visit_While = visit_For

analyzer = ComplexityAnalyzer()
analyzer.visit(ast.parse(source))
for name, score in analyzer.results.items():
    print(f"{name}: 圈复杂度 = {score}")
# simple: 圈复杂度 = 1
# complex_func: 圈复杂度 = 5
# complex_func scores 1 + outer if + for + inner if + elif = 5
# (the elif is a separate If node in the tree).

编译与执行

import ast

# mode="exec" parses a sequence of statements into an ast.Module, which is
# compiled in the matching "exec" mode and run with exec().
source = "print('Hello from AST!')"
tree = ast.parse(source, mode="exec")
code = compile(tree, "<ast>", "exec")
exec(code)
# Hello from AST!

# mode="eval" parses a single expression into an ast.Expression; it must be
# compiled in "eval" mode and evaluated with eval() to obtain its value.
expr_tree = ast.parse("2 ** 10", mode="eval")
result = eval(compile(expr_tree, filename="<ast>", mode="eval"))
print(result)  # 1024
AST 的实际应用
  • 代码格式化工具(如 Black)基于 AST 确保格式化不改变语义
  • 静态分析工具(如 Pylint、mypy)使用 AST 检查代码问题
  • AI 代码生成:分析/验证 LLM 生成代码的语法正确性
  • CPython 贡献:理解 AST 是参与编译器优化的前提