
在使用 python 进行数据可视化时,matplotlib 是一个功能强大的库,而 shap 库在生成解释性图表时通常会利用 matplotlib 进行渲染。当用户尝试将 shap.summary_plot 生成的图表保存为图像文件时,常见的问题是直接调用 matplotlib.pyplot.savefig() 可能会保存一个空白的图片。
这通常是由于 matplotlib 对图表(Figure)和坐标轴(Axes)的内部管理机制造成的。matplotlib.pyplot 模块提供了一系列便捷函数,它们通常操作“当前”的图表和坐标轴。如果 shap.summary_plot 在内部创建了一个新的图表对象,或者在绘制完成后将其关闭,那么紧接着调用的 plt.savefig() 可能会尝试保存一个默认的、空的“当前”图表,而不是我们期望的 SHAP 图。
解决此问题的关键在于显式地创建和管理 matplotlib 的 Figure 对象。通过创建一个 Figure 实例,然后确保 SHAP 图绘制在这个特定的 Figure 上,最后再通过该 Figure 实例的方法来保存图像,可以确保保存的是正确的图表。
具体步骤如下:
下面通过一个具体的代码示例来演示如何正确地保存 SHAP Summary Plot。我们将沿用原始问题中的模型和数据结构,但重点放在 SHAP 图的生成与保存上。
首先,确保安装了必要的库:numpy, tensorflow, shap, matplotlib。
import numpy as np
import shap
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
# 示例数据 (简化,仅用于演示目的)
X = np.array([[(1,2,3,3,1),(3,2,1,3,2),(3,2,2,3,3),(2,2,1,1,2),(2,1,1,1,1)],
[(4,5,6,4,4),(5,6,4,3,2),(5,5,6,1,3),(3,3,3,2,2),(2,3,3,2,1)],
[(7,8,9,4,7),(7,7,6,7,8),(5,8,7,8,8),(6,7,6,7,8),(5,7,6,6,6)],
[(7,8,9,8,6),(6,6,7,8,6),(8,7,8,8,8),(8,6,7,8,7),(8,6,7,8,8)],
[(4,5,6,5,5),(5,5,5,6,4),(6,5,5,5,6),(4,4,3,3,3),(5,5,4,4,5)],
[(4,5,6,5,5),(5,5,5,6,4),(6,5,5,5,6),(4,4,3,3,3),(5,5,4,4,5)],
[(1,2,3,3,1),(3,2,1,3,2),(3,2,2,3,3),(2,2,1,1,2),(2,1,1,1,1)]])
y = np.array([0, 1, 2, 2, 1, 1, 0])
# 构建并编译模型
model = keras.Sequential([
layers.Conv1D(128, kernel_size=3, activation='relu', input_shape=(5,5)),
layers.MaxPooling1D(pool_size=2),
layers.LSTM(128, return_sequences=True),
layers.Flatten(),
layers.Dense(128, activation='relu'),
layers.Dense(3, activation='softmax') # 假设3个类别,与y的实际值对应
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 训练模型
print("开始训练模型...")
model.fit(X, y, epochs=10, verbose=0) # verbose=0 不显示训练进度
print("模型训练完成。")接下来,使用训练好的模型和数据计算 SHAP 值。
# 初始化 SHAP explainer explainer = shap.GradientExplainer(model, X) shap_values = explainer.shap_values(X) # 定义要绘制的类别和特征索引 # shap_values 是一个列表,每个元素对应一个输出类别。 # 对于多分类模型,shap_values[cls] 是对应类别下的SHAP值数组。 # shap_values[cls][:,idx,:] 表示该类别下,所有样本的第idx个特征的所有维度SHAP值。 # X[:,idx,:] 对应所有样本的第idx个特征的所有维度原始数据。 cls = 0 # 示例:选择第一个输出类别 idx = 0 # 示例:选择第一个特征的SHAP值
这是核心部分,演示如何显式管理 matplotlib 图形对象以正确保存 SHAP 图。
# 1. 初始化一个 matplotlib Figure 对象
fig = plt.figure()
# 2. 绘制 SHAP summary_plot 到当前 Figure 上
# 注意:这里我们传入了 shap_values[cls][:,idx,:] 和 X[:,idx,:]
# 这意味着我们正在可视化特定类别 (cls) 下,特定特征索引 (idx) 的SHAP值。
# 根据你的模型输出和特征结构,你可能需要调整 cls 和 idx。
shap.summary_plot(shap_values[cls][:,idx,:], X[:,idx,:], show=False) # show=False 防止立即显示图表
# 3. 定义保存路径
save_path = 'shap_summary_plot.png'
# 4. 通过 Figure 对象保存图表
fig.savefig(save_path, bbox_inches='tight', dpi=300) # bbox_inches='tight' 裁剪空白边缘,dpi设置分辨率
# 5. 关闭 Figure 对象,释放内存
plt.close(fig)
print(f"SHAP summary plot 已成功保存到:{save_path}")fig.savefig('shap_summary_plot.svg', bbox_inches='tight') # 保存为SVG矢量图# 示例:保存第二个类别的SHAP图
fig2 = plt.figure()
shap.summary_plot(shap_values[1][:,idx,:], X[:,idx,:], show=False)
fig2.savefig('shap_summary_plot_cls1.png', bbox_inches='tight', dpi=300)
plt.close(fig2)正确保存 shap.summary_plot 的关键在于理解 matplotlib 的图表管理机制。通过显式地创建 Figure 对象,并在其上进行绘图,然后使用 Figure 实例的 savefig 方法,可以确保图表内容被准确地捕获和保存。这种方法不仅解决了空白图片的问题,也使得图表管理更加清晰和可控,是进行专业数据可视化输出的推荐实践。
以上就是如何将 SHAP Summary Plot 保存为高质量图像文件的详细内容,更多请关注php中文网其它相关文章!
每个人都需要一台速度更快、更稳定的 PC。随着时间的推移,垃圾文件、旧注册表数据和不必要的后台进程会占用资源并降低性能。幸运的是,许多工具可以让 Windows 保持平稳运行。
Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号