0

0

基于PaddlePaddle的中风患者线性模型训练和预测

P粉084495128

P粉084495128

发布时间:2025-07-28 11:44:49

|

746人浏览过

|

来源于php中文网

原创

该研究基于PaddlePaddle构建中风患者预测模型。使用含4981条数据的数据集,含性别、年龄等11个特征。先对分类变量序列化、数据标准化,切分训练集(80%)和测试集(20%),通过协相关分析特征关系。构建含全连接层的网络,经6轮训练,训练集正确率约91%,测试集评估正确率87.65%,模型有一定预测能力。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

基于paddlepaddle的中风患者线性模型训练和预测 - php中文网

一、基于PaddlePaddle的中风患者线性模型预测

基于PaddlePaddle的中风患者线性模型训练和预测 - php中文网        

1.背景描述

中风是一种医学疾病,由于流向大脑的血液不足导致细胞死亡。中风主要有两种类型:缺血性中风(缺乏血液流动导致)和出血性中风(出血导致)。两者都会导致大脑的某些部分停止正常运作。

中风的体征和症状可能包括一侧身体无法移动或感觉,理解或说话问题,头晕或一侧视力丧失。症状和体征通常在中风发生后不久就会出现。
如果症状持续不到一两个小时,中风就是短暂性脑缺血发作(TIA),也称为小中风。
出血性中风还可能伴有严重的头痛。中风的症状可能是永久性的。长期并发症可能包括肺炎和膀胱失控。

中风的主要危险因素是高血压。
其他危险因素包括高血胆固醇、吸烟、肥胖、糖尿病、以前的TIA、终末期肾病和心房颤动。

缺血性中风通常是由血管堵塞引起的,尽管也有一些不太常见的原因。
出血性中风是由出血直接进入大脑或进入大脑膜之间的空间引起的。

出血可能是由于脑动脉瘤破裂引起的。诊断通常基于身体检查,并辅以医学成像,如CT扫描或MRI扫描。
CT扫描可以排除出血,但不一定排除缺血,早期的CT扫描通常不会显示缺血。其他检查,如心电图(ECG)和血液检查,以确定危险因素和排除其他可能的原因。低血糖也可能引起类似的症状。

Digram
Digram

让Figma更好用的AI神器

下载

2.数据说明

《中国成人超重和肥胖症预防控制指南》的BMI分类:

BMI 身体质量指数说明
体重过轻
18.5 - 23.9 体重正常
24 - 27.9 超重
> 28 肥胖
  • 血糖水平
    正常空腹血糖浓度的预期值介于 70 mg/dL 到 100 mg/dL 之间。

    或:3.9 mmol/L 和 5.6 mmol/L 之间

二、数据分析

1.基础分析

In [1]
import numpy as npimport pandas as pd
   
In [2]
data = pd.read_csv('data/data225165/brain_stroke.csv',encoding='gbk')
data.head()
       
   性别    年龄  是否患有高血压  是否患有心脏病  是否有过婚姻  工作类型 住宅类型    血糖水平   BMI  吸烟状况  是否中风
0  男性  67.0        0        1       1  私人企业   城市  228.69  36.6  以前吸烟     1
1  男性  80.0        0        1       1  私人企业   农村  105.92  32.5  从不吸烟     1
2  女性  49.0        0        0       1  私人企业   城市  171.23  34.4    吸烟     1
3  女性  79.0        1        0       1  自雇人士   农村  174.12  24.0  从不吸烟     1
4  男性  81.0        0        0       1  私人企业   城市  186.21  29.0  以前吸烟     1
               
In [3]
data[data.duplicated()]
       
Empty DataFrame
Columns: [性别, 年龄, 是否患有高血压, 是否患有心脏病, 是否有过婚姻, 工作类型, 住宅类型, 血糖水平, BMI, 吸烟状况, 是否中风]
Index: []
               
In [4]
data.isnull().sum()
       
性别         0
年龄         0
是否患有高血压    0
是否患有心脏病    0
是否有过婚姻     0
工作类型       0
住宅类型       0
血糖水平       0
BMI        0
吸烟状况       0
是否中风       0
dtype: int64
               
In [5]
data.shape
       
(4981, 11)
               
In [6]
# 判断data中各个字段的取值是否与数据字典中一致# 以及判断是否存在额外的空值,如空格表示的空值for column in data.columns:    print(column + ":" + str(data[column].unique()))
       
性别:['男性' '女性']
年龄:[6.70e+01 8.00e+01 4.90e+01 7.90e+01 8.10e+01 7.40e+01 6.90e+01 7.80e+01
 6.10e+01 5.40e+01 5.00e+01 6.40e+01 7.50e+01 6.00e+01 7.10e+01 5.20e+01
 8.20e+01 6.50e+01 5.70e+01 4.20e+01 4.80e+01 7.20e+01 5.80e+01 7.60e+01
 3.90e+01 7.70e+01 6.30e+01 7.30e+01 5.60e+01 4.50e+01 7.00e+01 5.90e+01
 6.60e+01 4.30e+01 6.80e+01 4.70e+01 5.30e+01 3.80e+01 5.50e+01 4.60e+01
 3.20e+01 5.10e+01 1.40e+01 3.00e+00 8.00e+00 3.70e+01 4.00e+01 3.50e+01
 2.00e+01 4.40e+01 2.50e+01 2.70e+01 2.30e+01 1.70e+01 1.30e+01 4.00e+00
 1.60e+01 2.20e+01 3.00e+01 2.90e+01 1.10e+01 2.10e+01 1.80e+01 3.30e+01
 2.40e+01 3.60e+01 6.40e-01 3.40e+01 4.10e+01 8.80e-01 5.00e+00 2.60e+01
 3.10e+01 7.00e+00 1.20e+01 6.20e+01 2.00e+00 9.00e+00 1.50e+01 2.80e+01
 1.00e+01 1.80e+00 3.20e-01 1.08e+00 1.90e+01 6.00e+00 1.16e+00 1.00e+00
 1.40e+00 1.72e+00 2.40e-01 1.64e+00 1.56e+00 7.20e-01 1.88e+00 1.24e+00
 8.00e-01 4.00e-01 8.00e-02 1.48e+00 5.60e-01 1.32e+00 1.60e-01 4.80e-01]
是否患有高血压:[0 1]
是否患有心脏病:[1 0]
是否有过婚姻:[1 0]
工作类型:['私人企业' '自雇人士' '政府部门' '儿童']
住宅类型:['城市' '农村']
血糖水平:[228.69 105.92 171.23 ... 191.15  95.02  83.94]
BMI:[36.6 32.5 34.4 24.  29.  27.4 22.8 24.2 29.7 36.8 27.3 28.2 30.9 37.5
 25.8 37.8 22.4 48.9 26.6 27.2 23.5 28.3 44.2 25.4 22.2 30.5 26.5 33.7
 23.1 32.  29.9 23.9 28.5 26.4 20.2 33.6 38.6 39.2 27.7 31.4 36.5 33.2
 32.8 40.4 25.3 30.2 47.5 20.3 30.  28.9 28.1 31.1 21.7 27.  24.1 45.9
 44.1 22.9 29.1 32.3 41.1 25.6 29.8 26.3 26.2 29.4 24.4 28.  28.8 34.6
 19.4 30.3 41.5 22.6 27.1 31.3 31.  31.7 35.8 28.4 20.1 26.7 38.7 34.9
 25.  23.8 21.8 27.5 24.6 32.9 26.1 31.9 34.1 36.9 37.3 45.7 34.2 23.6
 22.3 37.1 45.  25.5 30.8 37.4 34.5 27.9 29.5 46.  42.5 35.5 26.9 45.5
 31.5 33.  23.4 30.7 20.5 21.5 40.  28.6 42.2 29.6 35.4 16.9 26.8 39.3
 32.6 35.9 21.2 42.4 40.5 36.7 29.3 19.6 18.  17.6 17.7 35.  22.  39.4
 19.7 22.5 25.2 41.8 23.7 24.5 31.2 16.  31.6 25.1 24.8 18.3 20.  19.5
 36.  35.3 40.1 43.1 21.4 34.3 27.6 16.5 24.3 25.7 21.9 38.4 25.9 18.6
 24.9 48.2 20.7 39.5 23.3 35.1 43.6 21.  47.3 16.6 21.6 15.5 35.6 16.7
 41.9 16.4 17.1 29.2 37.9 44.6 39.6 40.3 41.6 39.  23.2 18.9 36.1 36.3
 46.5 16.8 46.6 35.2 20.9 31.8 15.3 38.2 45.2 17.  27.8 23.  22.1 26.
 44.3 39.7 34.7 21.3 41.2 34.8 19.2 35.7 40.8 24.7 19.  32.4 34.  28.7
 32.1 20.4 30.6 19.3 40.9 17.2 16.1 16.2 40.6 18.4 21.1 42.3 32.2 17.5
 42.1 47.8 20.8 30.1 17.3 36.4 36.2 14.4 43.  41.7 33.8 43.9 22.7 18.7
 37.  38.5 16.3 44.  32.7 40.2 33.3 17.4 41.3 14.6 17.8 46.1 33.1 18.1
 43.8 38.9 43.7 39.9 15.9 19.8 38.3 41.  42.6 43.4 15.1 20.6 33.5 43.2
 19.1 30.4 38.  33.4 44.9 44.7 37.6 39.8 42.  37.2 42.8 18.8 42.9 14.3
 37.7 48.4 46.2 43.3 33.9 18.5 44.5 45.4 19.9 17.9 15.6 15.2 18.2 48.5
 14.1 15.7 44.8 38.1 44.4 38.8 39.1 41.4 14.2 15.4 45.1 48.7 42.7 48.8
 15.8 45.3 14.8 40.7 48.  46.8 48.3 14.5 15.  47.4 47.9 45.8 47.6 14.
 46.4 46.9 47.1 48.1 46.3 14.9]
吸烟状况:['以前吸烟' '从不吸烟' '吸烟' '不详']
是否中风:[1 0]
       

三、特征处理

1.特征分类变量序列化

In [7]
data.head()
       
   性别    年龄  是否患有高血压  是否患有心脏病  是否有过婚姻  工作类型 住宅类型    血糖水平   BMI  吸烟状况  是否中风
0  男性  67.0        0        1       1  私人企业   城市  228.69  36.6  以前吸烟     1
1  男性  80.0        0        1       1  私人企业   农村  105.92  32.5  从不吸烟     1
2  女性  49.0        0        0       1  私人企业   城市  171.23  34.4    吸烟     1
3  女性  79.0        1        0       1  自雇人士   农村  174.12  24.0  从不吸烟     1
4  男性  81.0        0        0       1  私人企业   城市  186.21  29.0  以前吸烟     1
               
In [8]
from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
   
In [9]
# 除去序号列columns=data.columnsprint(len(columns))for column in columns:    print(column)
       
11
性别
年龄
是否患有高血压
是否患有心脏病
是否有过婚姻
工作类型
住宅类型
血糖水平
BMI
吸烟状况
是否中风
       
In [10]
label_colum_encoder = ['性别', '工作类型', '住宅类型', '吸烟状况' ]
   
In [11]
for column in label_colum_encoder:    print(f"完成 {column} 列序列化")
    data[column]=le.fit_transform(data[column])
       
完成 性别 列序列化
完成 工作类型 列序列化
完成 住宅类型 列序列化
完成 吸烟状况 列序列化
       
In [12]
data.head()
       
   性别    年龄  是否患有高血压  是否患有心脏病  是否有过婚姻  工作类型  住宅类型    血糖水平   BMI  吸烟状况  是否中风
0   1  67.0        0        1       1     2     1  228.69  36.6     2     1
1   1  80.0        0        1       1     2     0  105.92  32.5     1     1
2   0  49.0        0        0       1     2     1  171.23  34.4     3     1
3   0  79.0        1        0       1     3     0  174.12  24.0     1     1
4   1  81.0        0        0       1     2     1  186.21  29.0     2     1
               

2.数据标准化

In [13]
columns=['年龄','工作类型','血糖水平','BMI','吸烟状况']for column in columns:
    col = data[column]
    col_min = col.min()
    col_max = col.max()
    normalized = (col - col_min) / (col_max - col_min)
    data[column] = normalized
   
In [14]
data.head()
       
   性别        年龄  是否患有高血压  是否患有心脏病  是否有过婚姻      工作类型  住宅类型      血糖水平       BMI  \
0   1  0.816895        0        1       1  0.666667     1  0.801265  0.647564   
1   1  0.975586        0        1       1  0.666667     0  0.234512  0.530086   
2   0  0.597168        0        0       1  0.666667     1  0.536008  0.584527   
3   0  0.963379        1        0       1  1.000000     0  0.549349  0.286533   
4   1  0.987793        0        0       1  0.666667     1  0.605161  0.429799   

       吸烟状况  是否中风  
0  0.666667     1  
1  0.333333     1  
2  1.000000     1  
3  0.333333     1  
4  0.666667     1
               

3.数据集切分

In [15]
from sklearn.model_selection import train_test_split# 切分数据集为 训练集 、 测试集train, test = train_test_split(data, test_size=0.2, random_state=2023)
   

3.协相关

In [16]
data.corr()
       
               性别        年龄   是否患有高血压   是否患有心脏病    是否有过婚姻      工作类型      住宅类型  \
性别       1.000000 -0.026538  0.021485  0.086476 -0.028971 -0.075975 -0.004301   
年龄      -0.026538  1.000000  0.278120  0.264852  0.677137  0.583042  0.017155   
是否患有高血压  0.021485  0.278120  1.000000  0.111974  0.164534  0.140098 -0.004755   
是否患有心脏病  0.086476  0.264852  0.111974  1.000000  0.114765  0.108356  0.002125   
是否有过婚姻  -0.028971  0.677137  0.164534  0.114765  1.000000  0.455567  0.008191   
工作类型    -0.075975  0.583042  0.140098  0.108356  0.455567  1.000000  0.004053   
住宅类型    -0.004301  0.017155 -0.004755  0.002125  0.008191  0.004053  1.000000   
血糖水平     0.055796  0.236763  0.170028  0.166847  0.150724  0.100118  0.001346   
BMI     -0.012093  0.373703  0.158762  0.060926  0.371690  0.378679  0.013185   
吸烟状况    -0.000653  0.305230  0.104703  0.085429  0.287190  0.318605  0.026798   
是否中风     0.008870  0.246478  0.131965  0.134610  0.108398  0.091301  0.016494   

             血糖水平       BMI      吸烟状况      是否中风  
性别       0.055796 -0.012093 -0.000653  0.008870  
年龄       0.236763  0.373703  0.305230  0.246478  
是否患有高血压  0.170028  0.158762  0.104703  0.131965  
是否患有心脏病  0.166847  0.060926  0.085429  0.134610  
是否有过婚姻   0.150724  0.371690  0.287190  0.108398  
工作类型     0.100118  0.378679  0.318605  0.091301  
住宅类型     0.001346  0.013185  0.026798  0.016494  
血糖水平     1.000000  0.186348  0.079654  0.133227  
BMI      0.186348  1.000000  0.245660  0.056926  
吸烟状况     0.079654  0.245660  1.000000  0.054793  
是否中风     0.133227  0.056926  0.054793  1.000000
               
In [17]
import matplotlib.pyplot as plt
%matplotlib inlineimport seaborn as sns
sns.set_style('whitegrid')# 热力图plt.figure(figsize=(20,12))
sns.heatmap(train.corr(), annot=True)
       
               
               

四、模型训练

1.网络定义

In [18]
import paddleimport paddle.nn.functional as F# 定义动态图class Net(paddle.nn.Layer):
    def __init__(self):
        super(Net, self).__init__()        # 定义一层全连接层,输出维度是1,激活函数为None,即不使用激活函数
        self.fc = paddle.nn.Linear(in_features=10,out_features=2)    
    # 网络的前向计算函数
    def forward(self, inputs):
        pred = self.fc(inputs)        return pred
   
In [19]
net=Net()
   

2.超参设置

In [20]
# 设置迭代次数epochs = 6#  paddle.nn.loss.CrossEntropyLoss正常#  paddle.nn.CrossEntropyLoss不正常loss_func = paddle.nn.CrossEntropyLoss()#优化器opt = paddle.optimizer.Adam(learning_rate=0.1,parameters=net.parameters())
   

3.模型训练

In [21]
#训练程序for epoch in range(epochs):
    all_acc = 0
    for i in range(train.shape[0]):
        x = paddle.to_tensor([train.iloc[i,:-1]])
        y = paddle.to_tensor([train.iloc[i,-1]])
        infer_y = net(x)
        loss = loss_func(infer_y,y)
        loss.backward()
        y=label = paddle.to_tensor([y], dtype="int64")
        acc= paddle.metric.accuracy(infer_y, y)
        all_acc=all_acc+acc.numpy()
        opt.step()
        opt.clear_gradients#清除梯度
        # print("epoch: {}, loss is: {},acc is:{}".format(epoch, loss.numpy(),acc.numpy()))  #由于输出过长,这里注释掉了
    print("第{}次正确率为:{}".format(epoch+1,all_acc/i))
       
第1次正确率为:[0.906352]
第2次正确率为:[0.913884]
第3次正确率为:[0.9131308]
第4次正确率为:[0.9131308]
第5次正确率为:[0.9199096]
第6次正确率为:[0.9113733]
       

五、模型评估

1.评估

In [22]
#测试集数据运行net.eval()#模型转换为测试模式all_acc = 0for i in range(test.shape[0]):
        x = paddle.to_tensor([test.iloc[i,:-1]])
        y = paddle.to_tensor([test.iloc[i,-1]])        
        infer_y = net(x)
        y=label = paddle.to_tensor([y], dtype="int64")    # 计算损失与精度
        loss = loss_func(infer_y, y)
        acc = paddle.metric.accuracy(infer_y, y)
        all_acc = all_acc+acc.numpy()    # 打印信息
        #print("loss is: {}, acc is: {}".format(loss.numpy(), acc.numpy()))print("测试集正确率:{}".format(all_acc/i))
       
测试集正确率:[0.87650603]
       

2.预测

In [23]
#预测结果展示net.eval()
x = paddle.to_tensor([test.iloc[0,:-1]])
y = paddle.to_tensor([test.iloc[0,-1]])    
infer_y = net(x)
y=label = paddle.to_tensor([y], dtype="int64")# 计算损失与精度loss = loss_func(infer_y, y)# 打印信息print("test[0] is :{}\n y_test[0] is :{}\n predict is {}".format(test.iloc[0,:-1] ,test.iloc[0,-1], np.argmax(infer_y.numpy()[0])))
       
test[0] is :性别         0.000000
年龄         0.218750
是否患有高血压    0.000000
是否患有心脏病    0.000000
是否有过婚姻     0.000000
工作类型       0.666667
住宅类型       1.000000
血糖水平       0.202613
BMI        0.329513
吸烟状况       0.666667
Name: 4121, dtype: float64
 y_test[0] is :0
 predict is 0
       

相关专题

更多
数据分析的方法
数据分析的方法

数据分析的方法有:对比分析法,分组分析法,预测分析法,漏斗分析法,AB测试分析法,象限分析法,公式拆解法,可行域分析法,二八分析法,假设性分析法。php中文网为大家带来了数据分析的相关知识、以及相关文章等内容。

447

2023.07.04

数据分析方法有哪几种
数据分析方法有哪几种

数据分析方法有:1、描述性统计分析;2、探索性数据分析;3、假设检验;4、回归分析;5、聚类分析。本专题为大家提供数据分析方法的相关的文章、下载、课程内容,供大家免费下载体验。

258

2023.08.07

网站建设功能有哪些
网站建设功能有哪些

网站建设功能包括信息发布、内容管理、用户管理、搜索引擎优化、网站安全、数据分析、网站推广、响应式设计、社交媒体整合和电子商务等功能。这些功能可以帮助网站管理员创建一个具有吸引力、可用性和商业价值的网站,实现网站的目标。

716

2023.10.16

数据分析网站推荐
数据分析网站推荐

数据分析网站推荐:1、商业数据分析论坛;2、人大经济论坛-计量经济学与统计区;3、中国统计论坛;4、数据挖掘学习交流论坛;5、数据分析论坛;6、网站数据分析;7、数据分析;8、数据挖掘研究院;9、S-PLUS、R统计论坛。想了解更多数据分析的相关内容,可以阅读本专题下面的文章。

498

2024.03.13

Python 数据分析处理
Python 数据分析处理

本专题聚焦 Python 在数据分析领域的应用,系统讲解 Pandas、NumPy 的数据清洗、处理、分析与统计方法,并结合数据可视化、销售分析、科研数据处理等实战案例,帮助学员掌握使用 Python 高效进行数据分析与决策支持的核心技能。

71

2025.09.08

Python 数据分析与可视化
Python 数据分析与可视化

本专题聚焦 Python 在数据分析与可视化领域的核心应用,系统讲解数据清洗、数据统计、Pandas 数据操作、NumPy 数组处理、Matplotlib 与 Seaborn 可视化技巧等内容。通过实战案例(如销售数据分析、用户行为可视化、趋势图与热力图绘制),帮助学习者掌握 从原始数据到可视化报告的完整分析能力。

54

2025.10.14

JavaScript ES6新特性
JavaScript ES6新特性

ES6是JavaScript的根本性升级,引入let/const实现块级作用域、箭头函数解决this绑定问题、解构赋值与模板字符串简化数据处理、对象简写与模块化提升代码可读性与组织性。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

0

2025.12.24

php框架基础知识汇总
php框架基础知识汇总

php框架是构建web应用程序的架构,提供工具和功能,以简化开发过程。选择合适的框架取决于项目需求和技能水平。实战案例展示了使用laravel构建博客的步骤,包括安装、创建模型、定义路由、编写控制器和呈现视图。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

1

2025.12.24

Word 字间距调整方法汇总
Word 字间距调整方法汇总

本专题整合了Word字间距调整方法,阅读下面的文章了解更详细操作。

2

2025.12.24

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
10分钟--Midjourney创作自己的漫画
10分钟--Midjourney创作自己的漫画

共1课时 | 0.1万人学习

Midjourney 关键词系列整合
Midjourney 关键词系列整合

共13课时 | 0.8万人学习

AI绘画教程
AI绘画教程

共2课时 | 0.2万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2025 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号