
本文深入探讨了在mpi4py中使用`comm.Gather`处理不同形状NumPy数组时遇到的挑战,并提供了两种有效的解决方案:利用`comm.gather`收集通用Python对象后进行拼接,以及使用`comm.Gatherv`直接将不同大小的数组高效地集合到一个预分配的NumPy缓冲区中。文章将详细阐述这两种方法的实现细节、适用场景及代码示例,帮助开发者优化并行程序的集合通信效率。
在并行计算中,经常需要在各个进程(或核心)上处理数据,然后将这些分散的结果收集到根进程上进行进一步的分析或整合。mpi4py库提供了强大的MPI(Message Passing Interface)绑定,使得Python程序能够方便地进行并行化。其中,comm.Gather是一个常用的集体通信操作,用于将所有进程的相同类型和形状的数据收集到根进程的一个连续缓冲区中。
然而,当每个进程需要发送的NumPy数组形状不一致时,直接使用comm.Gather会导致程序失败,因为它期望所有发送的数据都具有相同的维度和大小。本文将介绍两种在mpi4py中有效处理不同形状NumPy数组集合操作的方法:comm.gather(小写g)和comm.Gatherv(大写G,小写v)。
comm.Gather操作的本质是将所有进程的相同类型数据按顺序收集到根进程的一个预定义缓冲区中。这意味着每个发送进程的数据必须是同构的,即具有相同的形状和数据类型。
考虑以下示例,其中不同进程生成了形状不同的NumPy数组:
from mpi4py import MPI
import numpy as np
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
# rank 1生成(2,3)的数组,其他进程生成(5,3)的数组
a = np.zeros((2 if rank == 1 else 5, 3), dtype=float) + rank
print(f"Rank {rank}: 数组形状 {a.shape}, 数据:\n{a}")
# 尝试使用comm.Gather,这通常会失败
# b = np.zeros((12, 3), dtype=float) - 1 # 假设一个足够大的接收缓冲区
# comm.Gather(a, b, root=0)
# if rank == 0:
# print(f"Rank {rank}: 接收到的数据:\n{b}")运行上述代码中被注释掉的comm.Gather部分,会因为数组形状不匹配而导致运行时错误。为了解决这个问题,我们需要采用更灵活的集合通信方法。
comm.gather(注意是小写g)是mpi4py中一个更通用的集合操作。它不局限于NumPy数组,可以收集任何可序列化的Python对象。当每个进程发送的NumPy数组形状不同时,comm.gather会将其作为独立的Python对象进行收集,并在根进程上返回一个包含这些对象的列表或元组。随后,我们可以使用numpy.concatenate将这些数组拼接起来。
import numpy as np
from mpi4py import MPI
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
# rank 1生成(2,3)的数组,其他进程生成(5,3)的数组
a = np.zeros((2 if rank == 1 else 5, 3), dtype=float) + rank
print(f"Rank {rank}: 数组形状 {a.shape}, 数据:\n{a}")
# 使用comm.gather收集不同形状的数组
# 根进程会收到一个包含所有数组的列表
gathered_arrays = comm.gather(a, root=0)
if rank == 0:
print(f"\nRank {rank}: 原始收集到的数据 (列表形式):\n{gathered_arrays}")
# 将收集到的数组列表拼接成一个大数组
concatenated_array = np.concatenate(gathered_arrays, axis=0) # 沿0轴拼接
print(f"\nRank {rank}: 拼接后的数据形状 {concatenated_array.shape}, 数据:\n{concatenated_array}")
else:
# 非根进程的gathered_arrays为None
print(f"Rank {rank}: 非根进程,gathered_arrays为 {gathered_arrays}")comm.Gatherv(注意是Gatherv)是comm.Gather的变体,专门设计用于处理每个进程发送数据大小不同的情况。它允许将来自不同进程的、大小不一的数据直接集合到根进程的一个预分配的NumPy数组中,而无需中间的Python对象列表和后续的拼接操作。这通常在性能要求较高的场景下更为高效。
comm.Gatherv的接收缓冲区参数比comm.Gather复杂,它是一个元组,通常格式为 (recvbuf, recvcounts, displs, recvtype):
为了简化recvcounts和displs的计算,以下示例假设只有两个进程(size <= 2),但在实际应用中,这些参数通常需要通过comm.allgather或comm.gather先收集每个进程的形状信息来动态计算。
import numpy as np
from mpi4py import MPI
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
# 示例限制为两个进程,以便手动设置recvcounts和displs
assert size <= 2, "此Gatherv示例仅适用于2个或更少的进程"
if rank == 0:
a = np.zeros((5, 3), dtype=float) + rank
else: # rank == 1
a = np.zeros((2, 3), dtype=float) + rank
print(f"Rank {rank}: 数组形状 {a.shape}, 数据:\n{a}")
# 定义全局总行数 (5来自rank 0, 2来自rank 1)
n_global_rows = 7
# 定义每行元素数
cols = a.shape[1]
# 根进程需要预分配接收缓冲区
if rank == 0:
b = np.zeros((n_global_rows, cols), dtype=float)
# 计算每个进程发送的元素数量
# rank 0: 5行 * 3列 = 15个元素
# rank 1: 2行 * 3列 = 6个元素
recvcounts = [5 * cols, 2 * cols] # 对应每个进程的元素总数
# 计算每个进程数据在b中的起始偏移量 (元素偏移量)
# rank 0: 从b的0偏移量开始
# rank 1: 从b的第15个元素 (即第5行3列后) 开始
displs = [0, 5 * cols]
# 组合Gatherv的接收缓冲区参数
recvbuf_params = (b, recvcounts, displs, MPI.DOUBLE)
else:
b = None
recvbuf_params = None # 非根进程不需要提供接收缓冲区参数
# 执行Gatherv操作
comm.Gatherv(a, recvbuf_params, root=0)
if rank == 0:
print(f"\nRank {rank}: Gatherv接收到的数据形状 {b.shape}, 数据:\n{b}")
else:
print(f"Rank {rank}: 非根进程,b为 {b}")
当需要在mpi4py中将不同形状的NumPy数组收集到根进程时:
comm.gather (小写g):
comm.Gatherv (大写G,小写v):
在实际开发中,应根据具体的应用需求(数据规模、性能要求、代码复杂度等)权衡选择合适的方法。对于大多数情况,如果性能不是极致瓶颈,comm.gather配合np.concatenate是一个简单有效的方案。而对于高性能计算场景,comm.Gatherv则是更专业的选择。
以上就是mpi4py中异形NumPy数组的集合操作:gather与Gatherv详解的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号