Skip to main content

生成器和迭代器

info

生成器让Python能够以内存高效的方式处理大量数据。它们实现了惰性求值,只在需要时计算值,体现了Python对性能和内存使用的关注。

def fibonacci():
a, b = 0, 1
while True:
yield a
a, b = b, a + b

PEP 255 – 简单生成器

迭代器

迭代是 Python 最强大的功能之一,是访问集合元素的一种方式。

迭代器是一个可以记住遍历的位置的对象。

迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。迭代器只能往前不会后退。

迭代器有两个基本的方法:iter() 和 next()。

字符串,列表或元组对象都可用于创建迭代器:

list=[1,2,3,4]
it = iter(list) # 创建迭代器对象
next(it) # 输出迭代器的下一个元素
next(it) # 再输出下一个元素

enumerate

列表好处是不需要对下标进行迭代,直接输出列表的值:

x = [2, 4, 6]

for i in x:
print(i)

但是有些情况下,我们既希望获得下标, 也希望获得对应的值,那么:

可以将迭代器传给 enumerate 函数, 这样每次迭代都会返回一组 (index, value) 组成的元组:

x = [2, 4, 6]
for i, n in enumerate(x):
print(i, 'is', n)

自定义迭代器

一个迭代器都有 __iter__()__next__()

__iter__() 方法返回一个特殊的迭代器对象, 这个迭代器对象实现了 __next__() 方法并通过 StopIteration 异常标识迭代的完成。

__next__() 方法(Python 2 里是 next())会返回下一个迭代器对象。

自定义一个 list 的取反迭代器:

class ReverseListIterator(object):
def __init__(self, lst):
self.list = lst
self.index = len(lst)

def __iter__(self):
return self

def __next__(self):
self.index -= 1
if self.index >= 0:
return self.list[self.index]
else:
raise StopIteration
x = range(10)
for i in ReverseListIterator(x):
print(i)

只要我们定义了这三个方法(__init__, __iter__, __next__),我们可以返回任意迭代值:

实现 Collatz 猜想

这里我们实现 Collatz 猜想:

  • 奇数 n:返回 3n + 1
  • 偶数 n:返回 n / 2
  • 直到 n 为 1 为止:
class Collatz(object):
def __init__(self, start):
self.value = start

def __iter__(self):
return self

def __next__(self):
if self.value == 1:
raise StopIteration
elif self.value % 2 == 0:
self.value = self.value / 2
else:
self.value = 3 * self.value + 1
return self.value


for x in Collatz(5):
print(x)

不过迭代器对象存在状态,有问题

i = Collatz(5)
# zip() 函数用于将可迭代的对象作为参数,将对象中对应的元素打包成一个个元组,然后返回由这些元组组成的迭代器。
for x, y in zip(i, i):
print(x, y)

# 下方代码等价于上方代码
i = Collatz(5)
# *zipped 可理解为解压,返回二维矩阵式
zipped = zip(i, i)
#<zip object at 0x00000200CFC1F400> #返回的是一个对象
x, y = zip(*zipped)
print(x, y)

解决方法是将迭代器和可迭代对象分开处理。

迭代器和可迭代对象分开处理

这里提供了一个二分树的中序遍历实现:

class BinaryTree(object):
def __init__(self, value, left=None, right=None):
self.value = value
self.left = left
self.right = right

def __iter__(self):
return InorderIterator(self)

class InorderIterator(object):
def __init__(self, node):
self.node = node
self.stack = []

def __next__(self):
if len(self.stack) > 0 or self.node is not None:
while self.node is not None:
self.stack.append(self.node)
self.node = self.node.left
node = self.stack.pop()
self.node = node.right
return node.value
else:
raise StopIteration()

测试:

tree = BinaryTree(
left=BinaryTree(
left=BinaryTree(1),
value=2,
right=BinaryTree(
left=BinaryTree(3),
value=4,
right=BinaryTree(5)
),
),
value=6,
right=BinaryTree(
value=7,
right=BinaryTree(8)
)
)
for value in tree:
print(value)

不会出现之前的问题:


for x, y in zip(tree, tree):
print(x, y)

生成器

在 Python 中,使用了 yield 的函数被称为生成器(generator)。

跟普通函数不同的是,生成器是一个返回迭代器的函数,只能用于迭代操作,更简单点理解生成器就是一个迭代器。

  1. 迭代器则通过 next 的 return 将值返回;
  2. 与迭代器不同的是,生成器会自动记录当前的状态, 而迭代器则需要进行额外的操作来记录当前的状态。

之前的 collatz 猜想,简单循环的实现如下:

collatz:

  • 奇数 n:返回 3n + 1
  • 偶数 n:返回 n / 2
  • 直到 n 为 1 为止:
def collatz(n):
sequence = []
while n != 1:
if n % 2 == 0:
n /= 2
else:
n = 3 * n + 1
sequence.append(n)
return sequence


for x in collatz(5):
print(x)

生成器的版本如下:

def collatz(n):
while n != 1:
if n % 2 == 0:
n /= 2
else:
n = 3 * n + 1
yield n


for x in collatz(5):
print(x)

迭代器的版本如下:

class Collatz(object):
def __init__(self, start):
self.value = start

def __iter__(self):
return self

def next(self):
if self.value == 1:
raise StopIteration
elif self.value % 2 == 0:
self.value = self.value / 2
else:
self.value = 3 * self.value + 1
return self.value

for x in collatz(5):
print(x)

事实上,生成器也是一种迭代器:

x = collatz(5)
x

它支持 next 方法,返回下一个 yield 的值:

next(x)
next(x)

__iter__ 方法返回的是它本身:

x.__iter__()

return 和 yield 有什么区别?

yield 是暂停的意思(它有程序中起着类似红绿灯中等红灯的作用);yield 是创建迭代器,可以用 for 来遍历,有点事件触发的意思

return 在方法中直接返回值;是函数返回值,当执行到 return,后续的逻辑代码不在执行

相同点: 都是定义函数过程中返回值

不同点:yield 是暂停函数,return 是结束函数; 即 yield 返回值后继续执行函数体内代码,return 返回值后不再执行函数体内代码。

yield 返回的是一个迭代器(yield 本身是生成器-生成器是用来生成迭代器的);return 返回的是正常可迭代对象(list,set,dict 等具有实际内存地址的存储对象)

如果要返回的数据是通过 for 等循环生成的迭代器类型数据(如列表、元组),return 只能在循环外部一次性地返回,yeild 则可以在循环内部逐个元素返回。

yiled from 还可以使一个生成器可以委派子生成器,建立双向通道


def g1(x):
yield range(x, 0, -1)
yield range(x)
print(list(g1(5)))
#[range(5, 0, -1), range(0, 5)]

def g2(x):
yield from range(x, 0, -1)
yield from range(x)
print(list(g2(5)))
#[5, 4, 3, 2, 1, 0, 1, 2, 3, 4]

迭代器和生成器有什么区别?

在 Python 中,使用了 yield 的函数被称为生成器(generator)。跟普通函数不同的是,生成器是一个返回迭代器的函数,只能用于迭代操作,更简单点理解生成器就是一个迭代器。在调用生成器运行的过程中,每次遇到 yield 时函数会暂停并保存当前所有的运行信息,返回 yield 的值, 并在下一次执行 next() 方法时从当前位置继续运行。

调用一个生成器函数,返回的是一个迭代器对象:迭代是 Python 最强大的功能之一,是访问集合元素的一种方式。迭代器是一个可以记住遍历的位置的对象。迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。迭代器只能往前不会后退。迭代器有两个基本的方法:iter() 和 next()

newinit的区别?

执行顺序的不同:只有在new返回一个 cls 的实例时后面的init才能被调用

功能上的不同:当创建一个新实例时调用new,初始化一个实例时用init

返回值的不同:new方法会返回一个创建的实例,而init什么都不返回

推导式

推导式(Comprehensions)是 Python 中一种简洁、高效的创建数据结构的语法糖。它能够用一行代码完成原本需要多行循环才能实现的功能,不仅代码更加简洁,执行效率也更高。

列表推导式

列表推导式的基本语法结构如下:

[expression for item in iterable]

这相当于:

result = []
for item in iterable:
result.append(expression)

基础示例

# 创建平方数列表
squares = [x**2 for x in range(10)]
print(squares) # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

# 字符串转大写
words = ['hello', 'world', 'python']
upper_words = [word.upper() for word in words]
print(upper_words) # ['HELLO', 'WORLD', 'PYTHON']

# 提取数字的个位数
numbers = [123, 456, 789]
last_digits = [num % 10 for num in numbers]
print(last_digits) # [3, 6, 9]

带条件的列表推导式

条件过滤(if 语句)

# 基本语法
[expression for item in iterable if condition]

# 示例:筛选偶数
numbers = range(20)
even_numbers = [x for x in numbers if x % 2 == 0]
print(even_numbers) # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]

# 筛选正数并求平方
numbers = [-3, -2, -1, 0, 1, 2, 3]
positive_squares = [x**2 for x in numbers if x > 0]
print(positive_squares) # [1, 4, 9]

# 筛选特定长度的字符串
words = ['cat', 'dog', 'elephant', 'bird', 'python']
long_words = [word for word in words if len(word) > 4]
print(long_words) # ['elephant', 'python']

条件表达式(三元运算符)

# 语法:expression1 if condition else expression2
[expression1 if condition else expression2 for item in iterable]

# 示例:将负数转为0,正数保持不变
numbers = [-3, -1, 0, 2, 5]
processed = [x if x >= 0 else 0 for x in numbers]
print(processed) # [0, 0, 0, 2, 5]

# 根据条件设置不同的值
scores = [85, 92, 78, 96, 73]
grades = ['A' if score >= 90 else 'B' if score >= 80 else 'C' for score in scores]
print(grades) # ['B', 'A', 'C', 'A', 'C']

嵌套循环

# 基本语法
[expression for item1 in iterable1 for item2 in iterable2]

# 示例:生成坐标点
coordinates = [(x, y) for x in range(3) for y in range(3)]
print(coordinates)
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]

# 矩阵展平
matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
flattened = [num for row in matrix for num in row]
print(flattened) # [1, 2, 3, 4, 5, 6, 7, 8, 9]

# 带条件的嵌套循环
result = [(x, y) for x in range(5) for y in range(5) if x + y == 4]
print(result) # [(0, 4), (1, 3), (2, 2), (3, 1), (4, 0)]

复杂示例

# 处理字符串列表
sentences = ['Hello World', 'Python Programming', 'Data Science']
word_lengths = [[len(word) for word in sentence.split()] for sentence in sentences]
print(word_lengths) # [[5, 5], [6, 11], [4, 7]]

# 过滤和转换文件名
filenames = ['data.txt', 'image.jpg', 'script.py', 'document.pdf', 'code.py']
python_files = [filename.upper() for filename in filenames if filename.endswith('.py')]
print(python_files) # ['SCRIPT.PY', 'CODE.PY']

# 处理嵌套数据结构
students = [
{'name': 'Alice', 'scores': [85, 90, 88]},
{'name': 'Bob', 'scores': [78, 85, 92]},
{'name': 'Charlie', 'scores': [92, 88, 95]}
]
averages = [{'name': student['name'], 'average': sum(student['scores'])/len(student['scores'])}
for student in students]
print(averages)
# [{'name': 'Alice', 'average': 87.67}, {'name': 'Bob', 'average': 85.0}, {'name': 'Charlie', 'average': 91.67}]

字典推导式

基本语法

{key_expression: value_expression for item in iterable}

基础示例

# 创建平方数字典
squares_dict = {x: x**2 for x in range(5)}
print(squares_dict) # {0: 0, 1: 1, 2: 4, 3: 9, 4: 16}

# 字符串长度字典
words = ['apple', 'banana', 'cherry']
word_lengths = {word: len(word) for word in words}
print(word_lengths) # {'apple': 5, 'banana': 6, 'cherry': 6}

# 反转字典
original = {'a': 1, 'b': 2, 'c': 3}
reversed_dict = {v: k for k, v in original.items()}
print(reversed_dict) # {1: 'a', 2: 'b', 3: 'c'}

带条件的字典推导式

# 筛选偶数键值对
numbers = range(10)
even_squares = {x: x**2 for x in numbers if x % 2 == 0}
print(even_squares) # {0: 0, 2: 4, 4: 16, 6: 36, 8: 64}

# 过滤字典
scores = {'Alice': 85, 'Bob': 92, 'Charlie': 78, 'Diana': 96}
high_scores = {name: score for name, score in scores.items() if score >= 90}
print(high_scores) # {'Bob': 92, 'Diana': 96}

# 条件值设置
numbers = range(-3, 4)
abs_dict = {x: abs(x) if x < 0 else x for x in numbers}
print(abs_dict) # {-3: 3, -2: 2, -1: 1, 0: 0, 1: 1, 2: 2, 3: 3}

复杂示例

# 统计字符频率
text = "hello world"
char_count = {char: text.count(char) for char in set(text) if char != ' '}
print(char_count) # {'e': 1, 'h': 1, 'l': 3, 'o': 2, 'r': 1, 'd': 1, 'w': 1}

# 分组数据
students = ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve']
grouped = {len(name): [n for n in students if len(n) == len(name)] for name in students}
# 去重
grouped = {length: list(set(names)) for length, names in grouped.items()}
print(grouped) # {5: ['Alice', 'Diana'], 3: ['Bob', 'Eve'], 7: ['Charlie']}

# 嵌套字典处理
data = [
{'name': 'Alice', 'age': 25, 'city': 'New York'},
{'name': 'Bob', 'age': 30, 'city': 'London'},
{'name': 'Charlie', 'age': 35, 'city': 'Tokyo'}
]
name_to_info = {person['name']: {k: v for k, v in person.items() if k != 'name'}
for person in data}
print(name_to_info)
# {'Alice': {'age': 25, 'city': 'New York'}, 'Bob': {'age': 30, 'city': 'London'}, 'Charlie': {'age': 35, 'city': 'Tokyo'}}

集合推导式

基本语法

{expression for item in iterable}

示例

# 创建平方数集合
squares_set = {x**2 for x in range(10)}
print(squares_set) # {0, 1, 4, 9, 16, 25, 36, 49, 64, 81}

# 提取唯一字符
text = "hello world"
unique_chars = {char.upper() for char in text if char != ' '}
print(unique_chars) # {'H', 'E', 'L', 'O', 'W', 'R', 'D'}

# 过滤重复值
numbers = [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
unique_evens = {x for x in numbers if x % 2 == 0}
print(unique_evens) # {2, 4}

生成器表达式

基本语法

(expression for item in iterable)

生成器表达式与列表推导式类似,但使用圆括号,返回生成器对象而不是列表,具有内存效率优势。

示例

# 创建生成器
squares_gen = (x**2 for x in range(10))
print(type(squares_gen)) # <class 'generator'>

# 迭代生成器
for square in squares_gen:
print(square, end=' ') # 0 1 4 9 16 25 36 49 64 81

# 内存效率对比
import sys

# 列表推导式
list_comp = [x**2 for x in range(1000)]
print(f"列表大小: {sys.getsizeof(list_comp)} bytes")

# 生成器表达式
gen_exp = (x**2 for x in range(1000))
print(f"生成器大小: {sys.getsizeof(gen_exp)} bytes")

# 转换为其他数据结构
numbers = (x for x in range(5))
numbers_list = list(numbers) # [0, 1, 2, 3, 4]

基础示例

# 基本用法
age = 20
status = "成年人" if age >= 18 else "未成年人"
print(status) # 成年人

# 数值处理
x = -5
abs_x = x if x >= 0 else -x
print(abs_x) # 5

# 字符串处理
name = ""
display_name = name if name else "匿名用户"
print(display_name) # 匿名用户

嵌套三元表达式

# 多重条件
score = 85
grade = "A" if score >= 90 else "B" if score >= 80 else "C" if score >= 70 else "D"
print(grade) # B

# 更复杂的嵌套
temperature = 25
weather = "热" if temperature > 30 else "温暖" if temperature > 20 else "凉爽" if temperature > 10 else "寒冷"
print(weather) # 温暖

在函数中的应用

def get_max(a, b):
return a if a > b else b

def safe_divide(a, b):
return a / b if b != 0 else 0

def format_number(num):
return f"{num:,.2f}" if isinstance(num, (int, float)) else "无效数字"

print(get_max(10, 20)) # 20
print(safe_divide(10, 0)) # 0
print(format_number(1234.567)) # 1,234.57

常见错误与解决方案

1. 变量作用域问题

# 错误示例
functions = []
for i in range(5):
functions.append(lambda: i) # 所有lambda都引用同一个i
# 或 functions = [lambda : x for i in range(5)]

# 问题:所有函数都返回4
print([f() for f in functions]) # [4, 4, 4, 4, 4]

# 方法1:使用默认参数捕获值(简单,但是不推荐)
fs1 = [lambda i=i: i for i in range(3)]
print("方法1:", [f() for f in fs1])

# 方法2:使用闭包函数(推荐)
def make_function(x):
return lambda: x
fs2 = [make_function(i) for i in range(3)]
print("方法2:", [f() for f in fs2])

# 方法3:使用functools.partial
from functools import partial
fs3 = [partial(lambda x: x, i) for i in range(3)]
print("方法3:", [f() for f in fs3])

2. 过度嵌套

# 不好的做法:过度嵌套
result = [[[y**2 for y in x if y > 0] for x in row if len(x) > 2] for row in matrix if row]

# 好的做法:分步处理
filtered_matrix = [row for row in matrix if row]
filtered_rows = [x for row in filtered_matrix for x in row if len(x) > 2]
result = [[y**2 for y in x if y > 0] for x in filtered_rows]

3. 内存使用不当

# 大数据集时避免使用列表推导式
# 不好:一次性创建大列表
big_list = [expensive_operation(x) for x in range(1000000)]

# 好:使用生成器表达式
big_generator = (expensive_operation(x) for x in range(1000000))
for item in big_generator:
process(item) # 逐个处理,节省内存

内置函数

iter函数、next函数