
在数据处理中,我们经常需要对包含数组类型列的spark dataframe进行过滤。一个常见的需求是,筛选出那些数组列中至少包含给定python列表(例如 [item1, item2, ...])中一个或多个元素的行。
在SQL中,这种操作非常直观,通常可以使用arrays_overlap函数:
SELECT <columns> FROM <table> WHERE arrays_overlap(<array_column>, array(<list_elements>))
然而,当尝试将这种逻辑直接转换为PySpark时,许多用户会遇到困难。一个常见的错误尝试是:
from pyspark.sql.functions import col, array, arrays_overlap
# 假设 target_list 是一个 Python 列表,如 ['apple', 'banana']
df.filter(arrays_overlap(col("array_column"), array(target_list)))这段代码通常会导致AnalysisException,错误信息类似于[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name '<1st item in list>' cannot be resolved.。这是因为array()函数在接收非列参数时,期望的是字面量表达式(literal expressions),而不是原始的Python列表元素。虽然array_contains函数可以处理单个元素,但它无法满足与整个列表进行交集判断的需求。
解决这个问题的关键在于,将Python列表中的每个元素转换为Spark的字面量表达式(literal expression),然后再用array函数将其组合成一个字面量数组。这可以通过pyspark.sql.functions.lit函数来实现。
lit函数的作用是将一个Python值转换为一个Spark列表达式,这个表达式代表着一个常量值。当我们将列表中的每个元素都通过lit转换后,再将这些字面量表达式传递给array函数,array函数就能正确地构建一个包含这些字面量值的数组。
正确的PySpark实现如下:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, array, arrays_overlap, lit
# 1. 初始化 SparkSession
spark = SparkSession.builder.appName("ArrayColumnFilter").getOrCreate()
# 2. 准备示例数据
data = [
    (1, ["apple", "banana", "orange"]),
    (2, ["grape", "kiwi"]),
    (3, ["banana", "strawberry"]),
    (4, ["mango", "pineapple"]),
    (5, ["apple", "grape"])
]
df = spark.createDataFrame(data, ["id", "fruits_array"])
df.printSchema()
df.show()
# 3. 定义用于过滤的 Python 列表
target_list = ["banana", "grape", "lemon"]
# 4. 构建正确的过滤条件
# 使用 map(lit, target_list) 将列表中的每个元素转换为 lit 表达式
# 使用 * 解包这些 lit 表达式作为 array 函数的参数
# 最后,使用 arrays_overlap 进行比较
filtered_df = df.filter(
    arrays_overlap(col("fruits_array"), array(*map(lit, target_list)))
)
# 5. 显示过滤结果
print(f"\n原始DataFrame:")
df.show()
print(f"\n过滤列表:{target_list}")
print("\n过滤后的DataFrame(fruits_array与target_list有交集):")
filtered_df.show()
# 6. 停止 SparkSession
spark.stop()运行结果示例:
root |-- id: long (nullable = true) |-- fruits_array: array (nullable = true) | |-- element: string (nullable = true) +---+--------------------+ | id| fruits_array| +---+--------------------+ | 1|[apple, banana, o...| | 2| [grape, kiwi]| | 3|[banana, strawber...| | 4|[mango, pineapple]| | 5| [apple, grape]| +---+--------------------+ 原始DataFrame: +---+--------------------+ | id| fruits_array| +---+--------------------+ | 1|[apple, banana, o...| | 2| [grape, kiwi]| | 3|[banana, strawber...| | 4|[mango, pineapple]| | 5| [apple, grape]| +---+--------------------+ 过滤列表:['banana', 'grape', 'lemon'] 过滤后的DataFrame(fruits_array与target_list有交集): +---+--------------------+ | id| fruits_array| +---+--------------------+ | 1|[apple, banana, o...| | 2| [grape, kiwi]| | 3|[banana, strawber...| | 5| [apple, grape]| +---+--------------------+
从结果可以看出,id为1、2、3、5的行被保留,因为它们的fruits_array列与["banana", "grape", "lemon"]存在交集(例如,id=1包含"banana",id=2包含"grape",id=3包含"banana",id=5包含"grape")。
在PySpark中,当需要使用一个Python列表与DataFrame的数组列进行交集过滤时,务必记住使用pyspark.sql.functions.lit函数将列表中的每个元素转换为Spark字面量表达式。然后,通过array(*map(lit, your_list))的方式构建一个字面量数组,并将其作为arrays_overlap函数的第二个参数。这种模式是处理这类复杂数组过滤逻辑的标准且正确的方法,能够确保代码的健壮性和准确性。
以上就是在PySpark中利用数组列与列表交集进行DataFrame过滤的正确姿势的详细内容,更多请关注php中文网其它相关文章!
                        
                        每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
                Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号