
在数据处理中,我们经常会遇到数据以数组形式存储在 DataFrame 的列中。例如,一个数据框可能包含一个 id 数组列和一个 label 数组列,它们是按索引一一对应的。我们的目标是从 label 数组中找到最大值,并获取 id 数组中对应索引位置的元素,同时保留原始行的其他信息。
考虑以下 PySpark DataFrame 示例:
+-----------+-----------+------+ | id | label | md | +-----------+-----------+------+ |[a, b, c] | [1, 4, 2] | 3 | |[b, d] | [7, 2] | 1 | |[a, c] | [1, 2] | 8 |
我们期望的输出是:
+----+-----+------+ | id |label| md | +----+-----+------+ | b | 4 | 3 | | b | 7 | 1 | | c | 2 | 8 |
这要求我们能够将两个数组列的元素按索引进行配对,然后对配对后的值进行聚合操作。
为了解决上述问题,我们将利用 PySpark 的几个核心函数:
整个流程可以概括为:将 id 和 label 数组按元素配对并展开成多行,然后对展开后的数据使用窗口函数找出每组的最大 label 值及其对应的 id。
首先,我们需要一个 SparkSession 并创建与问题描述相符的示例 DataFrame。
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
# 初始化 SparkSession
spark = SparkSession.builder \
.appName("GetMaxFromArrayColumn") \
.getOrCreate()
# 创建示例数据
data = [
(["a", "b", "c"], [1, 4, 2], 3),
(["b", "d"], [7, 2], 1),
(["a", "c"], [1, 2], 8)
]
df = spark.createDataFrame(data, ["id", "label", "md"])
df.show(truncate=False)输出:
+---------+---------+---+ |id |label |md | +---------+---------+---+ |[a, b, c]|[1, 4, 2]|3 | |[b, d] |[7, 2] |1 | |[a, c] |[1, 2] |8 | +---------+---------+---+
使用 arrays_zip 将 id 和 label 列合并成一个结构体数组。例如,[a,b,c] 和 [1,4,2] 会变成 [{id:a, label:1}, {id:b, label:4}, {id:c, label:2}]。 然后,使用 inline 函数将这个结构体数组扁平化。inline 会将数组中的每个结构体转换为 DataFrame 的一行,并将其字段作为新的列。
# 使用 selectExpr 结合 inline 和 arrays_zip
# 原始的 'md' 列会被保留,而 'id' 和 'label' 列会被扁平化
df_exploded = df.selectExpr("md", "inline(arrays_zip(id, label))")
df_exploded.show(truncate=False)输出:
+---+----+-----+ |md |id |label| +---+----+-----+ |3 |a |1 | |3 |b |4 | |3 |c |2 | |1 |b |7 | |1 |d |2 | |8 |a |1 | |8 |c |2 | +---+----+-----+
现在,每一行代表了原始数组中的一个 (id, label) 对,并且 md 列标识了它们所属的原始行。
接下来,我们需要在每个原始行(由 md 列标识)的上下文中找到 label 列的最大值。这可以通过定义一个窗口并应用 max 聚合函数来实现。
# 定义窗口,按 'md' 列分区
# 这里的 'md' 列被假定为原始行的唯一标识符
w = Window.partitionBy("md")
# 在每个窗口内计算 'label' 列的最大值,并将其作为新列 'mx_label' 添加
df_with_max_label = df_exploded.withColumn("mx_label", F.max("label").over(w))
df_with_max_label.show(truncate=False)输出:
+---+----+-----+--------+ |md |id |label|mx_label| +---+----+-----+--------+ |1 |b |7 |7 | |1 |d |2 |7 | |3 |a |1 |4 | |3 |b |4 |4 | |3 |c |2 |4 | |8 |a |1 |2 | |8 |c |2 |2 | +---+----+-----+--------+
最后一步是过滤出那些 label 值等于其所在组最大 label 值的行,然后删除辅助列 mx_label。
# 过滤出 label 等于 mx_label 的行
final_df = df_with_max_label.filter(F.col("label") == F.col("mx_label")) \
.drop("mx_label")
# 根据期望输出调整列的顺序
final_df = final_df.select("id", "label", "md")
final_df.show(truncate=False)输出:
+---+-----+---+ |id |label|md | +---+-----+---+ |b |7 |1 | |b |4 |3 | |c |2 |8 | +---+-----+---+
这与我们期望的输出完全一致。
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window
# 初始化 SparkSession
spark = SparkSession.builder \
.appName("GetMaxFromArrayColumn") \
.getOrCreate()
# 创建示例数据
data = [
(["a", "b", "c"], [1, 4, 2], 3),
(["b", "d"], [7, 2], 1),
(["a", "c"], [1, 2], 8)
]
df = spark.createDataFrame(data, ["id", "label", "md"])
print("原始 DataFrame:")
df.show(truncate=False)
# 步骤1 & 2: 合并 'id' 和 'label' 数组并扁平化
# 使用 selectExpr 结合 inline 和 arrays_zip
df_exploded = df.selectExpr("md", "inline(arrays_zip(id, label))")
print("扁平化后的 DataFrame:")
df_exploded.show(truncate=False)
# 步骤3: 定义窗口并计算每个原始行的最大 'label' 值
# 假设 'md' 列唯一标识原始 DataFrame 的每一行
w = Window.partitionBy("md")
df_with_max_label = df_exploded.withColumn("mx_label", F.max("label").over(w))
print("添加最大值列后的 DataFrame:")
df_with_max_label.show(truncate=False)
# 步骤4 & 5: 过滤出最大值对应的行并删除辅助列,调整列顺序
final_df = df_with_max_label.filter(F.col("label") == F.col("mx_label")) \
.drop("mx_label") \
.select("id", "label", "md") # 调整列顺序
print("最终结果 DataFrame:")
final_df.show(truncate=False)
# 停止 SparkSession
spark.stop()w_ordered = Window.partitionBy("md").orderBy(F.col("label").desc(), F.lit(1)) # lit(1) for stable order if labels are equal
df_with_rank = df_exploded.withColumn("rank", F.row_number().over(w_ordered))
final_df = df_with_rank.filter(F.col("rank") == 1).drop("rank")本教程展示了如何利用 PySpark 的 arrays_zip、inline 和窗口函数来高效地解决从数组列中提取最大值及其对应索引元素的问题。这种组合方法是处理复杂数组操作的强大工具,能够保持代码的简洁性和执行效率,是 PySpark 数据处理中值得掌握的技巧。理解这些函数的协同工作方式,有助于在面对类似数组转换需求时构建健壮且高性能的解决方案。
以上就是PySpark 数据框中从一个数组列获取最大值并从另一列获取对应索引值的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号