
在数据处理中,我们经常会遇到多维数组中包含缺失值(nan)的情况。例如,一个3d numpy数组可能代表了多组(第一维度)2d数据,每组2d数据又包含行和列。我们的目标是针对每一组2d数据,计算其所有列的均值,同时忽略计算中的nan值,然后用这些计算出的列均值来填充原始数组中对应列的nan值。
考虑以下一个形状为(2, 3, 3)的3D NumPy数组作为示例:
import numpy as np
a = np.array([[[1, 2, 3], [4, np.nan, 6], [7, 8, 9]],
[[11, 12, 13], [14, np.nan, 16], [17, 18, 19]]])
print("原始数组形状:", a.shape)
print("原始数组:\n", a)输出:
原始数组形状: (2, 3, 3) 原始数组: [[[ 1. 2. 3.] [ 4. nan 6.] [ 7. 8. 9.]] [[11. 12. 13.] [14. nan 16.] [17. 18. 19.]]]
在这个数组中,a[0]和a[1]分别代表了两组2D数据。我们希望对a[0]的第二列(索引为1)计算均值,即(2 + 8) / 2 = 5,然后用5填充a[0, 1, 1]处的NaN。同样,对于a[1]的第二列,计算均值(12 + 18) / 2 = 15,并用15填充a[1, 1, 1]处的NaN。
期望的结果数组如下:
[[[ 1., 2., 3.], [ 4., 5., 6.], [ 7., 8., 9.]], [[11., 12., 13.], [14., 15., 16.], [17., 18., 19.]]]
NumPy提供了一个专门用于处理包含NaN值的均值计算函数 np.nanmean()。结合NumPy强大的广播(broadcasting)机制,我们可以高效地实现上述目标。
首先,我们需要计算每个2D子数组的列均值。对于一个形状为(dim0, dim1, dim2)的3D数组,如果我们想计算每个dim0切片(即每个2D子数组)的列均值,我们需要指定axis=1。这是因为axis=0代表第一个维度(2D子数组的索引),axis=1代表第二个维度(2D子数组的行索引),axis=2代表第三个维度(2D子数组的列索引)。当我们对axis=1求均值时,它会沿着行方向进行聚合,从而得到每列的均值。
# 计算每个2D子数组的列均值,忽略NaN值
# axis=1 表示在第二个维度上进行求均值操作,即对每个2D切片的列求均值
means = np.nanmean(a, axis=1)
print("\n计算出的列均值 (shape: {}):\n{}".format(means.shape, means))输出:
计算出的列均值 (shape: (2, 3)): [[ 4. 5. 6.] [14. 15. 16.]]
这里,means数组的形状是(2, 3)。means[0]对应原始数组a[0]的列均值 [4., 5., 6.],其中5.是(2+8)/2的结果。means[1]对应a[1]的列均值 [14., 15., 16.],其中15.是(12+18)/2的结果。
现在我们有了每个2D子数组的列均值,但means的形状是(2, 3),而原始数组a的形状是(2, 3, 3)。为了使用np.where函数将这些均值正确地广播到原始数组的相应NaN位置,我们需要将means的形状调整为(2, 1, 3)。通过在第二个维度上添加一个新轴(np.newaxis),可以实现这一点。
# 调整均值数组的形状,使其能够正确广播
# means[:, np.newaxis, :] 将形状从 (2, 3) 变为 (2, 1, 3)
means_reshaped = means[:, np.newaxis, :]
print("\n重塑后的列均值 (shape: {}):\n{}".format(means_reshaped.shape, means_reshaped))输出:
重塑后的列均值 (shape: (2, 1, 3)): [[[ 4. 5. 6.]] [[14. 15. 16.]]]
现在,means_reshaped的形状是(2, 1, 3)。当它与形状为(2, 3, 3)的原始数组a进行广播操作时:
最后一步是使用np.where()函数来条件性地替换NaN值。np.where(condition, x, y)的含义是:如果condition为真,则取x中的值;否则,取y中的值。
# 使用np.where函数填充NaN值
# 如果a中的元素是NaN,则用重塑后的列均值填充;否则保留a中的原始值
a_filled = np.where(np.isnan(a), means_reshaped, a)
print("\n填充NaN后的数组:\n", a_filled)输出:
填充NaN后的数组: [[[ 1. 2. 3.] [ 4. 5. 6.] [ 7. 8. 9.]] [[11. 12. 13.] [14. 15. 16.] [17. 18. 19.]]]
可以看到,原始数组中的NaN值已经被正确地替换为对应列的均值。
import numpy as np
# 原始3D数组,包含NaN值
a = np.array([[[1, 2, 3], [4, np.nan, 6], [7, 8, 9]],
[[11, 12, 13], [14, np.nan, 16], [17, 18, 19]]])
print("原始数组:\n", a)
print("原始数组形状:", a.shape)
# 1. 计算每个2D子数组的列均值,忽略NaN
# axis=1 表示在第二个维度上进行求均值,即对每个2D切片的列求均值
means = np.nanmean(a, axis=1)
print("\n计算出的列均值 (shape: {}):\n{}".format(means.shape, means))
# 2. 调整均值数组的形状以进行广播
# np.newaxis 在指定位置插入一个新维度,将 (2, 3) 变为 (2, 1, 3)
means_reshaped = means[:, np.newaxis, :]
print("\n重塑后的列均值 (shape: {}):\n{}".format(means_reshaped.shape, means_reshaped))
# 3. 使用np.where填充NaN值
# 如果a中的元素是NaN,则用重塑后的列均值填充;否则保留a中的原始值
a_filled = np.where(np.isnan(a), means_reshaped, a)
print("\n填充NaN后的数组:\n", a_filled)通过掌握np.nanmean()、np.newaxis和np.where()的组合使用,可以高效且优雅地处理NumPy多维数组中包含NaN值的复杂数据清洗和填充任务。
以上就是利用NumPy处理3D数组中包含NaN值的列均值计算与填充的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号