
本教程详细介绍了如何利用numpy的`np.lib.stride_tricks.as_strided`函数高效地对二维数组进行2x2块的修改。文章通过创建数组的“块视图”并结合查找表(lut)机制,避免了传统python循环的性能瓶颈。内容涵盖了多维索引和扁平化索引两种lut构建方法,并提供了详细的代码示例与注意事项,旨在帮助读者掌握numpy高级技巧,优化大规模数组的块级操作性能。
在数据处理和科学计算中,我们经常需要对二维数组的局部区域,特别是固定大小的块(例如2x2)进行遍历和修改。传统的Python循环虽然直观,但在处理大型NumPy数组时效率低下,因为它无法充分利用NumPy底层C语言实现的优化。为了克服这一性能瓶颈,NumPy提供了一系列高级工具,其中np.lib.stride_tricks.as_strided是一个强大且灵活的函数,能够让我们以非传统的方式“查看”数组,从而实现高效的块级操作。本教程将深入探讨如何结合as_strided和查找表(Lookup Table, LUT)来高效地修改NumPy二维数组的2x2块。
np.lib.stride_tricks.as_strided是一个用于创建数组新视图的函数。它的强大之处在于,你可以手动指定新视图的形状(shape)和步长(strides),而无需复制原始数据。这意味着对视图的修改会直接反映在原始数组上,极大地提高了内存效率和操作速度。
要将一个二维数组A(例如ny行nx列)转换为一个由2x2块组成的视图,我们需要理解shape和strides的含义:
综合起来,strides参数将是(A.strides[0]*2, A.strides[1]*2, A.strides[0], A.strides[1])。
代码示例1:创建块视图
import numpy as np
# 假设原始数组A是一个10x10的0/1值数组
A = np.random.randint(0, 2, (10, 10))
print("原始数组 A:\n", A)
# 计算新视图的形状
# 如果A是(ny, nx),那么块视图的形状是(ny//2, nx//2, 2, 2)
block_rows = A.shape[0] // 2
block_cols = A.shape[1] // 2
# 创建块视图
# A.strides[0] 是行步长,A.strides[1] 是列步长
Av = np.lib.stride_tricks.as_strided(A,
shape=(block_rows, block_cols, 2, 2),
strides=(A.strides[0] * 2, A.strides[1] * 2, A.strides[0], A.strides[1]))
print("\n块视图 Av 的形状:", Av.shape)
# 验证 Av[0,0] 是否是 A 的左上角2x2块
print("\nAv[0,0] (第一个2x2块):\n", Av[0, 0])
print("\nA[0:2, 0:2] (A的左上角2x2块):\n", A[0:2, 0:2])
# 验证修改Av会影响A
Av[0, 0] = [[9, 9], [9, 9]]
print("\n修改 Av[0,0] 后,A 的左上角2x2块:\n", A[0:2, 0:2])一旦我们有了块视图Av,就可以使用查找表来根据每个2x2块的当前值来决定其新的值。查找表的构建方式可以有多种,这里介绍两种常见且高效的方法。
假设我们的2x2块中的元素都是0或1(布尔值)。一个2x2的块共有 $2^4 = 16$ 种可能的组合。
这种方法为查找表lut创建多个维度,每个维度对应2x2块中的一个元素的值。例如,一个lut的形状可以是(2, 2, 2, 2, 2, 2),其中前四个2代表输入块的四个元素([0,0], [0,1], [1,0], [1,1])的可能值(0或1),后两个2代表输出的2x2块。
构建查找表:
# lut 的形状:(输入块[0,0], 输入块[0,1], 输入块[1,0], 输入块[1,1], 输出块行, 输出块列) lut = np.zeros((2, 2, 2, 2, 2, 2), dtype=A.dtype) # 填充一些转换规则 (示例,根据实际需求定义) # 假设输入块 [[0,0],[0,0]] 转换为 [[1,1],[1,1]] lut[0, 0, 0, 0] = [[1, 1], [1, 1]] # 假设输入块 [[0,0],[0,1]] 转换为 [[1,1],[1,0]] lut[0, 0, 0, 1] = [[1, 1], [1, 0]] # 假设输入块 [[1,1],[0,0]] 转换为 [[1,1],[1,1]] lut[1, 1, 0, 0] = [[1, 1], [1, 1]] # 其他未定义的组合将保持为0(根据lut的初始化)
应用查找表:
通过高级索引,我们可以直接将Av中每个2x2块的四个元素作为lut的索引,从而一次性完成所有块的转换。
# 重新初始化A以进行演示
A = np.random.randint(0, 2, (10, 10))
print("应用多维LUT前的 A:\n", A)
block_rows = A.shape[0] // 2
block_cols = A.shape[1] // 2
Av = np.lib.stride_tricks.as_strided(A,
shape=(block_rows, block_cols, 2, 2),
strides=(A.strides[0] * 2, A.strides[1] * 2, A.strides[0], A.strides[1]))
# 使用高级索引应用LUT
# Av[...,0,0] 获取所有块的[0,0]元素组成的数组
# Av[...,0,1] 获取所有块的[0,1]元素组成的数组,以此类推
Av[:] = lut[Av[..., 0, 0], Av[..., 0, 1], Av[..., 1, 0], Av[..., 1, 1]]
print("\n应用多维LUT后的 A:\n", A)这种方法首先将每个2x2的0/1块转换成一个单一的整数索引(0-15),然后使用这个整数索引来查找一个一维的查找表。这种方式可以使查找表的定义更紧凑。
将2x2块转换为单一索引:
一个2x2的0/1块可以看作一个4位的二进制数。例如,块[[a,b],[c,d]]可以转换为索引 a*8 + b*4 + c*2 + d*1。
# 定义权重矩阵 weights = np.array([[8, 4], [2, 1]]) # 计算每个2x2块的扁平化索引 # (Av * weights) 会对每个2x2块内部进行元素级乘法 # .sum(axis=(2,3)) 会将每个2x2块内部的元素求和,得到一个 (block_rows, block_cols) 形状的索引数组 idx = (Av * weights).sum(axis=(2, 3))
构建扁平化查找表:
lut2的形状将是(16, 2, 2),其中16代表所有可能的输入块索引。
lut2 = np.zeros((16, 2, 2), dtype=A.dtype) # 填充一些转换规则 (示例) # 索引0 (即块[[0,0],[0,0]]) 转换为 [[1,1],[1,1]] lut2[0] = [[1, 1], [1, 1]] # 索引1 (即块[[0,0],[0,1]]) 转换为 [[1,1],[1,0]] lut2[1] = [[1, 1], [1, 0]] # 索引12 (即块[[1,1],[0,0]]) 转换为 [[1,1],[1,1]] lut2[12] = [[1, 1], [1, 1]] # 其他未定义的组合将保持为0
应用查找表:
# 重新初始化A以进行演示
A = np.random.randint(0, 2, (10, 10))
print("应用扁平化LUT前的 A:\n", A)
block_rows = A.shape[0] // 2
block_cols = A.shape[1] // 2
Av = np.lib.stride_tricks.as_strided(A,
shape=(block_rows, block_cols, 2, 2),
strides=(A.strides[0] * 2, A.strides[1] * 2, A.strides[0], A.strides[1]))
# 计算扁平化索引
idx = (Av * weights).sum(axis=(2, 3))
# 使用扁平化索引应用LUT
Av[:] = lut2[idx]
print("\n应用扁平化LUT后的 A:\n", A)as_strided创建的视图支持常规的NumPy切片操作。这意味着你可以只对原始数组的某个特定区域的块进行修改,而不是整个数组。
# 重新初始化A
A = np.random.randint(0, 2, (10, 10))
print("进行局部修改前的 A:\n", A)
block_rows = A.shape[0] // 2
block_cols = A.shape[1] // 2
Av = np.lib.stride_tricks.as_strided(A,
shape=(block_rows, block_cols, 2, 2),
strides=(A.strides[0] * 2, A.strides[1] * 2, A.strides[0], A.strides[1]))
# 假设我们只想修改Av中索引为 (2,2) 到 (3,3) 的块区域
# 使用扁平化LUT进行修改
weights = np.array([[8, 4], [2, 1]])
lut2 = np.zeros((16, 2, 2), dtype=A.dtype)
lut2[0] = [[1, 1], [1, 1]] # 示例规则
# 计算指定区域块的索引
idx_partial = (Av[2:4, 2:4] * weights).sum(axis=(2, 3))
# 对指定区域的块进行修改
Av[2:4, 2:4] = lut2[idx_partial]
print("\n进行局部修改后的 A:\n", A)通过巧妙地运用np.lib.stride_tricks.as_strided创建数组的块视图,并结合查找表机制,我们可以高效、内存友好地对NumPy二维数组的固定大小块进行批量修改。这种方法将复杂的循环逻辑转化为NumPy底层的向量化操作,显著提升了处理大规模数据的性能。掌握这一高级技巧,将使你在NumPy数据处理中如虎添翼。
以上就是使用NumPy高效修改二维数组:2x2块操作的Stride Tricks技巧的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号