This project uses PaddlePaddle 2.0.0rc and the 2010-2014 Beijing air-pollution dataset to predict the PM2.5 level at a given hour one day ahead, comparing an LSTM against a plain DNN. The data are preprocessed, standardized, and cut into sliding windows before the models are trained. The results show that the LSTM is the better fit for this kind of time-series task: its validation loss is lower and it shows no sign of overfitting.

# !python -m pip install paddlepaddle-gpu==2.0.0rc0.post101 -f https://paddlepaddle.org.cn/whl/stable.html
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
data = pd.read_csv('data/data55547/PRSA_data_2010.1.1-2014.12.31.csv')
# Check the size and dtypes of the data, and whether any values are missing
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 43824 entries, 0 to 43823
Data columns (total 13 columns):
No       43824 non-null int64
year     43824 non-null int64
month    43824 non-null int64
day      43824 non-null int64
hour     43824 non-null int64
pm2.5    41757 non-null float64
DEWP     43824 non-null int64
TEMP     43824 non-null float64
PRES     43824 non-null float64
cbwd     43824 non-null object
Iws      43824 non-null float64
Is       43824 non-null int64
Ir       43824 non-null int64
dtypes: float64(4), int64(8), object(1)
memory usage: 4.3+ MB
data[data['pm2.5'].isna()]
          No  year  month  day  hour  pm2.5  DEWP  TEMP    PRES cbwd     Iws  Is  Ir
0          1  2010      1    1     0    NaN   -21 -11.0  1021.0   NW    1.79   0   0
1          2  2010      1    1     1    NaN   -21 -12.0  1020.0   NW    4.92   0   0
2          3  2010      1    1     2    NaN   -21 -11.0  1019.0   NW    6.71   0   0
3          4  2010      1    1     3    NaN   -21 -14.0  1019.0   NW    9.84   0   0
4          5  2010      1    1     4    NaN   -20 -12.0  1018.0   NW   12.97   0   0
5          6  2010      1    1     5    NaN   -19 -10.0  1017.0   NW   16.10   0   0
6          7  2010      1    1     6    NaN   -19  -9.0  1017.0   NW   19.23   0   0
7          8  2010      1    1     7    NaN   -19  -9.0  1017.0   NW   21.02   0   0
8          9  2010      1    1     8    NaN   -19  -9.0  1017.0   NW   24.15   0   0
9         10  2010      1    1     9    NaN   -20  -8.0  1017.0   NW   27.28   0   0
10        11  2010      1    1    10    NaN   -19  -7.0  1017.0   NW   31.30   0   0
11        12  2010      1    1    11    NaN   -18  -5.0  1017.0   NW   34.43   0   0
12        13  2010      1    1    12    NaN   -19  -5.0  1015.0   NW   37.56   0   0
13        14  2010      1    1    13    NaN   -18  -3.0  1015.0   NW   40.69   0   0
14        15  2010      1    1    14    NaN   -18  -2.0  1014.0   NW   43.82   0   0
15        16  2010      1    1    15    NaN   -18  -1.0  1014.0   cv    0.89   0   0
16        17  2010      1    1    16    NaN   -19  -2.0  1015.0   NW    1.79   0   0
17        18  2010      1    1    17    NaN   -18  -3.0  1015.0   NW    2.68   0   0
18        19  2010      1    1    18    NaN   -18  -5.0  1016.0   NE    1.79   0   0
19        20  2010      1    1    19    NaN   -17  -4.0  1017.0   NW    1.79   0   0
20        21  2010      1    1    20    NaN   -17  -5.0  1017.0   cv    0.89   0   0
21        22  2010      1    1    21    NaN   -17  -5.0  1018.0   NW    1.79   0   0
22        23  2010      1    1    22    NaN   -17  -5.0  1018.0   NW    2.68   0   0
23        24  2010      1    1    23    NaN   -17  -5.0  1020.0   cv    0.89   0   0
545      546  2010      1   23    17    NaN   -18   2.0  1024.0   NW   91.22   0   0
546      547  2010      1   23    18    NaN   -18   1.0  1024.0   NW   96.14   0   0
547      548  2010      1   23    19    NaN   -17   0.0  1024.0   NW  100.16   0   0
548      549  2010      1   23    20    NaN   -18   0.0  1024.0   SE    1.79   0   0
549      550  2010      1   23    21    NaN   -15  -3.0  1024.0   cv    0.89   0   0
550      551  2010      1   23    22    NaN   -16   0.0  1023.0   NW    1.79   0   0
...      ...   ...    ...  ...   ...    ...   ...   ...     ...  ...     ...  ..  ..
42847  42848  2014     11   21     7    NaN    -3   0.0  1020.0   NW   11.17   0   0
42848  42849  2014     11   21     8    NaN    -3   0.0  1020.0   cv    0.89   0   0
43190  43191  2014     12    5    14    NaN   -22   4.0  1025.0   NW   41.12   0   0
43191  43192  2014     12    5    15    NaN   -22   3.0  1025.0   NE    1.79   0   0
43264  43265  2014     12    8    16    NaN   -13   3.0  1033.0   cv    1.79   0   0
43266  43267  2014     12    8    18    NaN   -11  -2.0  1034.0   SE    0.89   0   0
43267  43268  2014     12    8    19    NaN   -11  -2.0  1035.0   SE    1.78   0   0
43268  43269  2014     12    8    20    NaN   -11  -4.0  1036.0   SE    2.67   0   0
43269  43270  2014     12    8    21    NaN   -11  -5.0  1036.0   SE    3.56   0   0
43270  43271  2014     12    8    22    NaN   -11  -5.0  1036.0   NE    0.89   0   0
43273  43274  2014     12    9     1    NaN   -11  -4.0  1037.0   cv    0.89   0   0
43274  43275  2014     12    9     2    NaN   -10  -5.0  1036.0   SE    0.89   0   0
43275  43276  2014     12    9     3    NaN   -10  -6.0  1037.0   cv    0.89   0   0
43276  43277  2014     12    9     4    NaN   -10  -7.0  1036.0   cv    1.78   0   0
43277  43278  2014     12    9     5    NaN   -11  -6.0  1036.0   cv    2.67   0   0
43278  43279  2014     12    9     6    NaN   -11  -7.0  1036.0   cv    3.56   0   0
43279  43280  2014     12    9     7    NaN   -11  -8.0  1036.0   cv    4.45   0   0
43280  43281  2014     12    9     8    NaN    -9  -6.0  1036.0   SE    0.89   0   0
43281  43282  2014     12    9     9    NaN    -8  -5.0  1037.0   NE    1.79   0   0
43282  43283  2014     12    9    10    NaN    -8  -4.0  1037.0   cv    0.89   0   0
43283  43284  2014     12    9    11    NaN    -8  -3.0  1036.0   NE    1.79   0   0
43544  43545  2014     12   20     8    NaN   -18  -4.0  1031.0   NW  225.30   0   0
43545  43546  2014     12   20     9    NaN   -17  -4.0  1031.0   NW  228.43   0   0
43546  43547  2014     12   20    10    NaN   -18  -2.0  1031.0   NW  233.35   0   0
43547  43548  2014     12   20    11    NaN   -17  -1.0  1031.0   NW  239.16   0   0
43548  43549  2014     12   20    12    NaN   -18   0.0  1030.0   NW  244.97   0   0
43549  43550  2014     12   20    13    NaN   -19   1.0  1029.0   NW  249.89   0   0
43550  43551  2014     12   20    14    NaN   -20   1.0  1029.0   NW  257.04   0   0
43551  43552  2014     12   20    15    NaN   -20   2.0  1028.0   NW  262.85   0   0
43552  43553  2014     12   20    16    NaN   -21   1.0  1028.0   NW  270.00   0   0

[2067 rows x 13 columns]
data = data.iloc[24:].copy()
# The first 24 rows (hours 0-23 of 2010-01-01) are NaN and have nothing earlier
# to fill from, so drop them; forward-fill ('ffill') handles the remaining gaps.
data.fillna(method='ffill', inplace=True)
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 43800 entries, 24 to 43823
Data columns (total 13 columns):
No       43800 non-null int64
year     43800 non-null int64
month    43800 non-null int64
day      43800 non-null int64
hour     43800 non-null int64
pm2.5    43800 non-null float64
DEWP     43800 non-null int64
TEMP     43800 non-null float64
PRES     43800 non-null float64
cbwd     43800 non-null object
Iws      43800 non-null float64
Is       43800 non-null int64
Ir       43800 non-null int64
dtypes: float64(4), int64(8), object(1)
memory usage: 4.3+ MB
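Forward-filling simply repeats the last reading across a gap, which is crude for the longer NaN stretches seen above. A minimal alternative worth trying is linear interpolation; this is a sketch only (not used in the rest of the notebook; `data_alt` is a hypothetical copy):

data_alt = data.copy()
# Estimate each gap from its neighbouring observations instead of repeating the last one
data_alt['pm2.5'] = data_alt['pm2.5'].interpolate(method='linear')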
data.drop('No', axis=1, inplace=True)

import datetime
data['time'] = data.apply(lambda x: datetime.datetime(year=x['year'],
                                                      month=x['month'],
                                                      day=x['day'],
                                                      hour=x['hour']),
                          axis=1)
data.set_index('time', inplace=True)
data.drop(columns=['year', 'month', 'day', 'hour'], inplace=True)
data.head()
                     pm2.5  DEWP  TEMP    PRES cbwd   Iws  Is  Ir
time
2010-01-02 00:00:00  129.0   -16  -4.0  1020.0   SE  1.79   0   0
2010-01-02 01:00:00  148.0   -15  -4.0  1020.0   SE  2.68   0   0
2010-01-02 02:00:00  159.0   -11  -5.0  1021.0   SE  3.57   0   0
2010-01-02 03:00:00  181.0    -7  -5.0  1022.0   SE  5.36   1   0
2010-01-02 04:00:00  138.0    -7  -5.0  1022.0   SE  6.25   2   0
data.columns = ['pm2.5', 'dew', 'temp', 'press', 'cbwd', 'iws', 'snow', 'rain']
data.cbwd.unique()
array(['SE', 'cv', 'NW', 'NE'], dtype=object)
del data['cbwd']
data.info()
<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 43800 entries, 2010-01-02 00:00:00 to 2014-12-31 23:00:00
Data columns (total 7 columns):
pm2.5    43800 non-null float64
dew      43800 non-null int64
temp     43800 non-null float64
press    43800 non-null float64
iws      43800 non-null float64
snow     43800 non-null int64
rain     43800 non-null int64
dtypes: float64(4), int64(3)
memory usage: 2.7 MB
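Deleting cbwd discards the wind-direction signal entirely. If one wanted to keep it as a model input, a minimal sketch is one-hot encoding it with pandas (my addition, to be run before the `del data['cbwd']` step above; `data_wind` and `dummies` are hypothetical names):

data_wind = data.copy()   # take this copy while the cbwd column still exists
dummies = pd.get_dummies(data_wind['cbwd'], prefix='cbwd')   # cbwd_SE, cbwd_cv, cbwd_NW, cbwd_NE
data_wind = data_wind.join(dummies).drop(columns='cbwd')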
data['pm2.5'][-1000:].plot()
[Figure: PM2.5 over the last 1000 hours]
data['temp'][-1000:].plot()
[Figure: temperature over the last 1000 hours]
# Inspect the data
data.head(3)
                     pm2.5  dew  temp   press   iws  snow  rain
time
2010-01-02 00:00:00  129.0  -16  -4.0  1020.0  1.79     0     0
2010-01-02 01:00:00  148.0  -15  -4.0  1020.0  2.68     0     0
2010-01-02 02:00:00  159.0  -11  -5.0  1021.0  3.57     0     0
A question worth pausing on: why does the label not need standardization? Standardization exists to bring features with different scales into a common range, which helps backpropagation (without it, the network may chase scale differences between batches instead of the prediction task itself). The label can be standardized too, but then the predictions have to be transformed back to the original scale at inference time.
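For instance, a minimal sketch of that inverse transform (my addition; the toy numbers are illustrative, not real training labels):

import numpy as np

train_labels = np.array([129.0, 148.0, 159.0, 181.0, 138.0])   # toy pm2.5 values
pm_mean, pm_std = train_labels.mean(), train_labels.std()
z = (train_labels - pm_mean) / pm_std     # standardized labels fed to the model
recovered = z * pm_std + pm_mean          # inverse transform recovers the raw scale
assert np.allclose(recovered, train_labels)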
sequence_length = 5 * 24   # five days of hourly observations as input
delay = 24                 # predict 24 hours ahead

# Generate training sequences for use in the model.
def create_sequences(values, time_steps=sequence_length + delay):
    output = []
    for i in range(len(values) - time_steps):
        output.append(values[i : (i + time_steps)])
    return np.stack(output)

data_ = create_sequences(data.values)
print("Training input shape: ", data_.shape)
Training input shape:  (43656, 144, 7)
split_boundary = int(data_.shape[0] * 0.8)
train = data_[:split_boundary]
test = data_[split_boundary:]
# Standardize with statistics computed on the training split only
mean = train.mean(axis=0)
std = train.std(axis=0)
train = (train - mean) / std
test = (test - mean) / std
train.shape, test.shape
((34924, 144, 7), (8732, 144, 7))
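As a sanity check on the windowing, here is a tiny illustrative run of create_sequences on toy data (my addition; `toy` is a made-up array):

toy = np.arange(20).reshape(10, 2)              # 10 timesteps, 2 features
toy_windows = create_sequences(toy, time_steps=4)
print(toy_windows.shape)                        # (6, 4, 2): 6 overlapping windows of length 4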
# Data generator
def switch_reader(is_val: bool = False):
    def reader():
        # Decide whether this is the validation set
        if is_val:
            # Yield validation samples one by one
            for te in test:
                yield te[:sequence_length], te[-1:][:, 0]
        else:
            # Yield training samples; column 0 (pm2.5) of the last timestep is the label
            for tr in train:
                yield tr[:sequence_length], tr[-1:][:, 0]
    return reader  # NB: return the function itself, without parentheses

# Split into batches
batch_size = 128
train_reader = fluid.io.batch(reader=switch_reader(), batch_size=batch_size)
val_reader = fluid.io.batch(reader=switch_reader(is_val=True), batch_size=batch_size)
for data in train_reader():
    # print(data[0].shape, data[1].shape)
    train_x = np.array([x[0] for x in data], np.float32)
    train_y = np.array([x[1] for x in data]).astype('float32')  # float labels: this is regression
    print(train_x.shape, train_y.shape)

# Define the DNN network
class MyModel(fluid.dygraph.Layer):
    '''
    DNN network
    '''
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = fluid.dygraph.Linear(5*24*7, 32, act='relu')
        self.fc2 = fluid.dygraph.Linear(32, 1)

    def forward(self, input):
        '''Forward pass: defines the network's execution logic at run time'''
        # print('input', input.shape)
        input = fluid.layers.reshape(input, shape=[-1, 5*24*7])  # flatten the window
        out = self.fc1(input)
        out = self.fc2(out)
        # print(out.shape)
        return out

Batch = 0
Batchs = []
all_train_loss = []

def draw_train_loss(Batchs, train_loss, eval_loss):
    title = "training-eval loss"
    plt.title(title, fontsize=24)
    plt.xlabel("batch", fontsize=14)
    plt.ylabel("loss", fontsize=14)
    plt.plot(Batchs, train_loss, color='red', label='training loss')
    plt.plot(Batchs, eval_loss, color='g', label='eval loss')
    plt.legend()
    plt.grid()
    plt.show()

# place = fluid.CUDAPlace(0)  # do not use the GPU build on non-develop versions
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
    model = MyModel()   # instantiate the model
    model.train()       # training mode
    # opt = fluid.optimizer.SGDOptimizer(learning_rate=train_parameters['learning_strategy']['lr'], parameter_list=model.parameters())  # alternative: plain SGD with learning rate 0.001
    opt = fluid.optimizer.AdamOptimizer(learning_rate=0.0001, parameter_list=model.parameters())
    epochs_num = 100    # number of epochs
    batch_size = 128*16
    train_reader = fluid.io.batch(reader=switch_reader(), batch_size=batch_size)
    val_reader = fluid.io.batch(reader=switch_reader(is_val=True), batch_size=batch_size)
    Batch = 0
    Batchs = []
    all_train_loss = []
    all_eval_loss = []
    for pass_num in range(epochs_num):
        for batch_id, data in enumerate(train_reader()):
            data_x = np.array([x[0] for x in data], np.float32)
            data_y = np.array([x[1] for x in data]).astype('float32')
            data_x = fluid.dygraph.to_variable(data_x)
            data_y = fluid.dygraph.to_variable(data_y)
            # print(data_x.shape, data_y.shape)
            predict = model(data_x)
            # print(predict.shape)
            loss = fluid.layers.mse_loss(predict, data_y)
            avg_loss = fluid.layers.mean(loss)   # scalar loss value
            avg_loss.backward()
            opt.minimize(avg_loss)      # the optimizer's minimize() updates the parameters
            model.clear_gradients()     # reset gradients for the next step
            if batch_id != 0 and batch_id % 10 == 0:
                Batch = Batch + 10
                Batchs.append(Batch)
                all_train_loss.append(avg_loss.numpy()[0])
        evalavg_loss = []
        for eval_data in val_reader():
            eval_data_x = np.array([x[0] for x in eval_data], np.float32)
            eval_data_y = np.array([x[1] for x in eval_data]).astype('float32')
            eval_data_x = fluid.dygraph.to_variable(eval_data_x)
            eval_data_y = fluid.dygraph.to_variable(eval_data_y)
            eval_predict = model(eval_data_x)
            eval_loss = fluid.layers.mse_loss(eval_predict, eval_data_y)
            eval_loss = fluid.layers.mean(eval_loss)
            evalavg_loss.append(eval_loss.numpy()[0])   # collect the batch loss
        all_eval_loss.append(sum(evalavg_loss)/len(evalavg_loss))
        print("epoch:{},batch_id:{},train_loss:{},eval_loss:{}".format(pass_num, batch_id, avg_loss.numpy(), sum(evalavg_loss)/len(evalavg_loss)))
    fluid.save_dygraph(model.state_dict(), 'MyModel')   # save the model parameters
    fluid.save_dygraph(opt.state_dict(), 'MyModel')     # save the optimizer state
    print("Final loss: {}".format(avg_loss.numpy()))
    # Plot the training and validation losses to see how training progressed.
    draw_train_loss(Batchs, all_train_loss, all_eval_loss)
epoch:0,batch_id:17,train_loss:[2.0205],eval_loss:1.4090836882591247
epoch:1,batch_id:17,train_loss:[0.95689076],eval_loss:1.3361332535743713
epoch:2,batch_id:17,train_loss:[0.7040673],eval_loss:1.2218480825424194
epoch:3,batch_id:17,train_loss:[0.55722934],eval_loss:1.1956807255744935
epoch:4,batch_id:17,train_loss:[0.44944313],eval_loss:1.1633899331092834
epoch:5,batch_id:17,train_loss:[0.37596697],eval_loss:1.1420036435127259
epoch:6,batch_id:17,train_loss:[0.31873935],eval_loss:1.1268895626068116
epoch:7,batch_id:17,train_loss:[0.27411735],eval_loss:1.1125162959098815
epoch:8,batch_id:17,train_loss:[0.2403403],eval_loss:1.1013256669044496
epoch:9,batch_id:17,train_loss:[0.21393616],eval_loss:1.0918826699256896
epoch:10,batch_id:17,train_loss:[0.19293499],eval_loss:1.0833844304084779
epoch:11,batch_id:17,train_loss:[0.17653875],eval_loss:1.076257163286209
epoch:12,batch_id:17,train_loss:[0.16238903],eval_loss:1.0695580899715424
epoch:13,batch_id:17,train_loss:[0.15100743],eval_loss:1.0639597535133363
epoch:14,batch_id:17,train_loss:[0.14186515],eval_loss:1.0592819035053254
epoch:15,batch_id:17,train_loss:[0.13479522],eval_loss:1.055191159248352
epoch:16,batch_id:17,train_loss:[0.1290898],eval_loss:1.051255214214325
epoch:17,batch_id:17,train_loss:[0.12424426],eval_loss:1.047574871778488
epoch:18,batch_id:17,train_loss:[0.11999645],eval_loss:1.0441474676132203
epoch:19,batch_id:17,train_loss:[0.11639561],eval_loss:1.0410736680030823
epoch:20,batch_id:17,train_loss:[0.11316744],eval_loss:1.038214284181595
epoch:21,batch_id:17,train_loss:[0.11018123],eval_loss:1.0353180170059204
epoch:22,batch_id:17,train_loss:[0.10779685],eval_loss:1.032786226272583
epoch:23,batch_id:17,train_loss:[0.10557291],eval_loss:1.0302606165409087
epoch:24,batch_id:17,train_loss:[0.1037445],eval_loss:1.0279349327087401
epoch:25,batch_id:17,train_loss:[0.10192361],eval_loss:1.025689673423767
epoch:26,batch_id:17,train_loss:[0.10021695],eval_loss:1.023529589176178
epoch:27,batch_id:17,train_loss:[0.0984721],eval_loss:1.0216342866420747
epoch:28,batch_id:17,train_loss:[0.09707484],eval_loss:1.019782018661499
epoch:29,batch_id:17,train_loss:[0.0957087],eval_loss:1.0179351627826692
epoch:30,batch_id:17,train_loss:[0.09425645],eval_loss:1.0161666870117188
epoch:31,batch_id:17,train_loss:[0.09265903],eval_loss:1.0144980549812317
epoch:32,batch_id:17,train_loss:[0.09125529],eval_loss:1.0129337430000305
epoch:33,batch_id:17,train_loss:[0.08980759],eval_loss:1.0113472878932952
epoch:34,batch_id:17,train_loss:[0.08829899],eval_loss:1.0098000168800354
epoch:35,batch_id:17,train_loss:[0.08709818],eval_loss:1.0085288822650909
epoch:36,batch_id:17,train_loss:[0.08586626],eval_loss:1.0073016047477723
epoch:37,batch_id:17,train_loss:[0.08476242],eval_loss:1.0060584604740144
epoch:38,batch_id:17,train_loss:[0.08362537],eval_loss:1.0050600707530974
epoch:39,batch_id:17,train_loss:[0.08273471],eval_loss:1.0040717482566834
epoch:40,batch_id:17,train_loss:[0.08195919],eval_loss:1.0030322968959808
epoch:41,batch_id:17,train_loss:[0.0810699],eval_loss:1.0020980298519135
epoch:42,batch_id:17,train_loss:[0.07989511],eval_loss:1.0009490311145783
epoch:43,batch_id:17,train_loss:[0.07878471],eval_loss:0.9999779641628266
epoch:44,batch_id:17,train_loss:[0.07754707],eval_loss:0.9990507960319519
epoch:45,batch_id:17,train_loss:[0.07625636],eval_loss:0.997998195886612
epoch:46,batch_id:17,train_loss:[0.07513986],eval_loss:0.9971686065196991
epoch:47,batch_id:17,train_loss:[0.07390005],eval_loss:0.9962048828601837
epoch:48,batch_id:17,train_loss:[0.07286156],eval_loss:0.9953225016593933
epoch:49,batch_id:17,train_loss:[0.07175022],eval_loss:0.9946246147155762
epoch:50,batch_id:17,train_loss:[0.07077469],eval_loss:0.993957793712616
epoch:51,batch_id:17,train_loss:[0.06977923],eval_loss:0.993251645565033
epoch:52,batch_id:17,train_loss:[0.06907593],eval_loss:0.992446219921112
epoch:53,batch_id:17,train_loss:[0.06824756],eval_loss:0.991847711801529
epoch:54,batch_id:17,train_loss:[0.06763344],eval_loss:0.9912112653255463
epoch:55,batch_id:17,train_loss:[0.06695005],eval_loss:0.9905830025672913
epoch:56,batch_id:17,train_loss:[0.06627547],eval_loss:0.9900696039199829
epoch:57,batch_id:17,train_loss:[0.06573104],eval_loss:0.9896724104881287
epoch:58,batch_id:17,train_loss:[0.06506079],eval_loss:0.9892310202121735
epoch:59,batch_id:17,train_loss:[0.06436179],eval_loss:0.9887569844722748
epoch:60,batch_id:17,train_loss:[0.06374478],eval_loss:0.9883864879608154
epoch:61,batch_id:17,train_loss:[0.06303963],eval_loss:0.9881407439708709
epoch:62,batch_id:17,train_loss:[0.06245909],eval_loss:0.9878709852695465
epoch:63,batch_id:17,train_loss:[0.06174919],eval_loss:0.9875110030174256
epoch:64,batch_id:17,train_loss:[0.06118464],eval_loss:0.987206107378006
epoch:65,batch_id:17,train_loss:[0.06051154],eval_loss:0.9869666278362275
epoch:66,batch_id:17,train_loss:[0.05986768],eval_loss:0.9865923523902893
epoch:67,batch_id:17,train_loss:[0.05928758],eval_loss:0.9863128185272216
epoch:68,batch_id:17,train_loss:[0.05866254],eval_loss:0.9859303057193756
epoch:69,batch_id:17,train_loss:[0.05802014],eval_loss:0.9856755137443542
epoch:70,batch_id:17,train_loss:[0.0575587],eval_loss:0.9854108214378356
epoch:71,batch_id:17,train_loss:[0.05704111],eval_loss:0.985070925951004
epoch:72,batch_id:17,train_loss:[0.05671573],eval_loss:0.9848090887069703
epoch:73,batch_id:17,train_loss:[0.05617322],eval_loss:0.9845478892326355
epoch:74,batch_id:17,train_loss:[0.05566153],eval_loss:0.9842156410217285
epoch:75,batch_id:17,train_loss:[0.05529902],eval_loss:0.9840305268764495
epoch:76,batch_id:17,train_loss:[0.05462031],eval_loss:0.9837329030036926
epoch:77,batch_id:17,train_loss:[0.05434851],eval_loss:0.9835087835788727
epoch:78,batch_id:17,train_loss:[0.05377433],eval_loss:0.9832845091819763
epoch:79,batch_id:17,train_loss:[0.05343863],eval_loss:0.9830455482006073
epoch:80,batch_id:17,train_loss:[0.05288152],eval_loss:0.982842218875885
epoch:81,batch_id:17,train_loss:[0.05258711],eval_loss:0.982667338848114
epoch:82,batch_id:17,train_loss:[0.05217287],eval_loss:0.9824033558368683
epoch:83,batch_id:17,train_loss:[0.05160918],eval_loss:0.9821954727172851
epoch:84,batch_id:17,train_loss:[0.05129151],eval_loss:0.9820389568805694
epoch:85,batch_id:17,train_loss:[0.05077891],eval_loss:0.9820009410381317
epoch:86,batch_id:17,train_loss:[0.05045455],eval_loss:0.9819312691688538
epoch:87,batch_id:17,train_loss:[0.04997],eval_loss:0.9818430423736573
epoch:88,batch_id:17,train_loss:[0.04965632],eval_loss:0.9816549181938171
epoch:89,batch_id:17,train_loss:[0.04909806],eval_loss:0.9816236138343811
epoch:90,batch_id:17,train_loss:[0.04883103],eval_loss:0.9815687894821167
epoch:91,batch_id:17,train_loss:[0.04832352],eval_loss:0.9815601170063019
epoch:92,batch_id:17,train_loss:[0.04800665],eval_loss:0.9814506828784942
epoch:93,batch_id:17,train_loss:[0.04761852],eval_loss:0.9812910079956054
epoch:94,batch_id:17,train_loss:[0.04736731],eval_loss:0.9812990665435791
epoch:95,batch_id:17,train_loss:[0.04682],eval_loss:0.9812341630458832
epoch:96,batch_id:17,train_loss:[0.04646796],eval_loss:0.9810558021068573
epoch:97,batch_id:17,train_loss:[0.04601882],eval_loss:0.9810874044895173
epoch:98,batch_id:17,train_loss:[0.04565503],eval_loss:0.9811647534370422
epoch:99,batch_id:17,train_loss:[0.04528417],eval_loss:0.9811281561851501
Final loss: [0.04528417]
[Figure: training vs. eval loss for the DNN]
The state update equations are as follows (the standard LSTM update used by paddle.nn.LSTM):

i_t = σ(W_ii x_t + b_ii + W_hi h_{t-1} + b_hi)
f_t = σ(W_if x_t + b_if + W_hf h_{t-1} + b_hf)
o_t = σ(W_io x_t + b_io + W_ho h_{t-1} + b_ho)
c̃_t = tanh(W_ic x_t + b_ic + W_hc h_{t-1} + b_hc)
c_t = f_t ⊙ c_{t-1} + i_t ⊙ c̃_t
h_t = o_t ⊙ tanh(c_t)

inputs (Tensor) - the network input. If time_major is True, its shape is [time_steps, batch_size, input_size]; if time_major is False, it is [batch_size, time_steps, input_size].
initial_states (tuple, optional) - the initial states of the network, a tuple of (h, c), each of shape [num_layers * num_directions, batch_size, hidden_size]. Zero-initialized if not given.
sequence_length (Tensor, optional) - the length of each input sequence, of shape [batch_size] and dtype int64 or int32. Timesteps at positions not less than sequence_length are treated as padding (the state stops updating).
outputs (Tensor) - the output, the concatenation of the forward and backward cell outputs. If time_major is True, its shape is [time_steps, batch_size, num_directions * hidden_size]; if time_major is False, it is [batch_size, time_steps, num_directions * hidden_size]. num_directions is 2 when direction is bidirectional, otherwise 1.
final_states (tuple) - the final states, a tuple of (h, c), each of shape [num_layers * num_directions, batch_size, hidden_size]. num_directions is 2 when direction is bidirectional, otherwise 1.
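A quick shape check of the layer can make these conventions concrete. The following sketch is my addition (sizes chosen to match MyLSTMModel below; it assumes Paddle 2.0's default dynamic-graph mode):

import paddle
import numpy as np

lstm = paddle.nn.LSTM(input_size=7, hidden_size=14, num_layers=2)
x = paddle.to_tensor(np.random.randn(8, 120, 7).astype('float32'))  # [batch, time_steps, input_size]
outputs, (h, c) = lstm(x)
print(outputs.shape)   # [8, 120, 14]: the hidden state at every timestep
print(h.shape)         # [2, 8, 14]:   final h per layer (num_layers * num_directions = 2)
print(c.shape)         # [2, 8, 14]:   final c per layer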
# Define the LSTM network
import paddle
import paddle.fluid as fluid

class MyLSTMModel(fluid.dygraph.Layer):
    '''
    LSTM network
    '''
    def __init__(self):
        super(MyLSTMModel, self).__init__()
        self.rnn = paddle.nn.LSTM(7, 14, 2)   # input_size=7, hidden_size=14, num_layers=2
        self.flatten = paddle.nn.Flatten()
        self.fc1 = fluid.dygraph.Linear(120*14, 120)
        self.fc2 = fluid.dygraph.Linear(120, 1)

    def forward(self, input):
        '''Forward pass: defines the network's execution logic at run time'''
        # print('input', input.shape)
        out, (h, c) = self.rnn(input)
        out = self.flatten(out)    # [batch, 120, 14] -> [batch, 120*14]
        out = self.fc1(out)
        out = self.fc2(out)
        return out

Batch = 0
Batchs = []
all_train_loss = []

def draw_train_loss(Batchs, train_loss, eval_loss):
    title = "training-eval loss"
    plt.title(title, fontsize=24)
    plt.xlabel("batch", fontsize=14)
    plt.ylabel("loss", fontsize=14)
    plt.plot(Batchs, train_loss, color='red', label='training loss')
    plt.plot(Batchs, eval_loss, color='g', label='eval loss')
    plt.legend()
    plt.grid()
    plt.show()

import paddle
# place = fluid.CUDAPlace(0)  # do not use the GPU build on non-develop versions
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
    model = MyLSTMModel()   # instantiate the model
    # model = MyModel()
    model.train()           # training mode
    # opt = fluid.optimizer.SGDOptimizer(learning_rate=0.001, parameter_list=model.parameters())  # alternative: plain SGD with learning rate 0.001
    opt = fluid.optimizer.AdamOptimizer(learning_rate=0.01, parameter_list=model.parameters())
    epochs_num = 100    # number of epochs
    batch_size = 128*32
    train_reader = fluid.io.batch(reader=switch_reader(), batch_size=batch_size)
    val_reader = fluid.io.batch(reader=switch_reader(is_val=True), batch_size=batch_size)
    Batch = 0
    Batchs = []
    all_train_loss = []
    all_eval_loss = []
    for pass_num in range(epochs_num):
        for batch_id, data in enumerate(train_reader()):
            data_x = np.array([x[0] for x in data], np.float32)
            data_y = np.array([x[1] for x in data]).astype('float32')
            data_x = fluid.dygraph.to_variable(data_x)
            data_y = fluid.dygraph.to_variable(data_y)
            # print(data_x.shape, data_y.shape)
            predict = model(data_x)
            # print(predict.shape)
            loss = fluid.layers.mse_loss(predict, data_y)
            avg_loss = fluid.layers.mean(loss)   # scalar loss value
            avg_loss.backward()
            opt.minimize(avg_loss)      # the optimizer's minimize() updates the parameters
            model.clear_gradients()     # reset gradients for the next step
            if batch_id != 0 and batch_id % 1 == 0:
                Batch = Batch + 1
                Batchs.append(Batch)
                all_train_loss.append(avg_loss.numpy()[0])
        evalavg_loss = []
        for eval_data in val_reader():
            eval_data_x = np.array([x[0] for x in eval_data], np.float32)
            eval_data_y = np.array([x[1] for x in eval_data]).astype('float32')
            eval_data_x = fluid.dygraph.to_variable(eval_data_x)
            eval_data_y = fluid.dygraph.to_variable(eval_data_y)
            eval_predict = model(eval_data_x)
            eval_loss = fluid.layers.mse_loss(eval_predict, eval_data_y)
            eval_loss = fluid.layers.mean(eval_loss)
            evalavg_loss.append(eval_loss.numpy()[0])   # collect the batch loss
        all_eval_loss.append(sum(evalavg_loss)/len(evalavg_loss))
        print("epoch:{},batch_id:{},train_loss:{},eval_loss:{}".format(pass_num, batch_id, avg_loss.numpy(), sum(evalavg_loss)/len(evalavg_loss)))
    fluid.save_dygraph(model.state_dict(), 'MyLSTMModel')   # save the model parameters
    fluid.save_dygraph(opt.state_dict(), 'MyLSTMModel')     # save the optimizer state
    print("Final loss: {}".format(avg_loss.numpy()))
    # Plot the training and validation losses to see how training progressed.
    draw_train_loss(Batchs, all_train_loss, all_eval_loss)
epoch:0,batch_id:8,train_loss:[41.62476],eval_loss:13.937688509623209
epoch:1,batch_id:8,train_loss:[4.161157],eval_loss:2.6484082142512
epoch:2,batch_id:8,train_loss:[2.1240506],eval_loss:1.698279857635498
epoch:3,batch_id:8,train_loss:[1.1397613],eval_loss:1.2127376794815063
epoch:4,batch_id:8,train_loss:[1.1065184],eval_loss:1.201335072517395
epoch:5,batch_id:8,train_loss:[1.1207557],eval_loss:1.1899906992912292
epoch:6,batch_id:8,train_loss:[1.126892],eval_loss:1.1028050978978474
epoch:7,batch_id:8,train_loss:[1.1262866],eval_loss:1.0896229942639668
epoch:8,batch_id:8,train_loss:[1.1331279],eval_loss:1.1011923948923747
epoch:9,batch_id:8,train_loss:[1.1255071],eval_loss:1.101571758588155
epoch:10,batch_id:8,train_loss:[1.1172327],eval_loss:1.0972675879796345
epoch:11,batch_id:8,train_loss:[1.1123648],eval_loss:1.0952287912368774
epoch:12,batch_id:8,train_loss:[1.1086842],eval_loss:1.0921181639035542
epoch:13,batch_id:8,train_loss:[1.1045169],eval_loss:1.086412250995636
epoch:14,batch_id:8,train_loss:[1.1000217],eval_loss:1.0816428860028584
epoch:15,batch_id:8,train_loss:[1.0957059],eval_loss:1.0777405301729839
epoch:16,batch_id:8,train_loss:[1.091319],eval_loss:1.073056121667226
epoch:17,batch_id:8,train_loss:[1.0871797],eval_loss:1.0684852600097656
epoch:18,batch_id:8,train_loss:[1.0834234],eval_loss:1.0644978284835815
epoch:19,batch_id:8,train_loss:[1.0798335],eval_loss:1.0606069564819336
epoch:20,batch_id:8,train_loss:[1.0764899],eval_loss:1.056948721408844
epoch:21,batch_id:8,train_loss:[1.0734138],eval_loss:1.05355566740036
epoch:22,batch_id:8,train_loss:[1.0705017],eval_loss:1.0503225127855937
epoch:23,batch_id:8,train_loss:[1.0677806],eval_loss:1.0473219752311707
epoch:24,batch_id:8,train_loss:[1.0652552],eval_loss:1.0444998741149902
epoch:25,batch_id:8,train_loss:[1.0628968],eval_loss:1.0418291091918945
epoch:26,batch_id:8,train_loss:[1.0606785],eval_loss:1.0393112301826477
epoch:27,batch_id:8,train_loss:[1.058571],eval_loss:1.0369138320287068
epoch:28,batch_id:8,train_loss:[1.0565668],eval_loss:1.0346330006917317
epoch:29,batch_id:8,train_loss:[1.0546503],eval_loss:1.03245347738266
epoch:30,batch_id:8,train_loss:[1.0528067],eval_loss:1.0303666790326436
epoch:31,batch_id:8,train_loss:[1.0510274],eval_loss:1.0283629894256592
epoch:32,batch_id:8,train_loss:[1.0493041],eval_loss:1.026433030764262
epoch:33,batch_id:8,train_loss:[1.0476311],eval_loss:1.0245701869328816
epoch:34,batch_id:8,train_loss:[1.0460036],eval_loss:1.0227669874827068
epoch:35,batch_id:8,train_loss:[1.0444185],eval_loss:1.0210176308949788
epoch:36,batch_id:8,train_loss:[1.0428716],eval_loss:1.0193172097206116
epoch:37,batch_id:8,train_loss:[1.041361],eval_loss:1.0176609953244526
epoch:38,batch_id:8,train_loss:[1.039884],eval_loss:1.016045093536377
epoch:39,batch_id:8,train_loss:[1.0384375],eval_loss:1.0144659479459126
epoch:40,batch_id:8,train_loss:[1.0370196],eval_loss:1.0129202802975972
epoch:41,batch_id:8,train_loss:[1.0356268],eval_loss:1.0114047129948933
epoch:42,batch_id:8,train_loss:[1.0342562],eval_loss:1.0099159677823384
epoch:43,batch_id:8,train_loss:[1.0329046],eval_loss:1.0084505478541057
epoch:44,batch_id:8,train_loss:[1.0315686],eval_loss:1.0070040822029114
epoch:45,batch_id:8,train_loss:[1.030245],eval_loss:1.005572259426117
epoch:46,batch_id:8,train_loss:[1.02893],eval_loss:1.0041507482528687
epoch:47,batch_id:8,train_loss:[1.027621],eval_loss:1.0027351379394531
epoch:48,batch_id:8,train_loss:[1.0263156],eval_loss:1.0013217131296794
epoch:49,batch_id:8,train_loss:[1.0250111],eval_loss:0.9999080300331116
epoch:50,batch_id:8,train_loss:[1.0237058],eval_loss:0.9984927773475647
epoch:51,batch_id:8,train_loss:[1.0223969],eval_loss:0.9970751603444418
epoch:52,batch_id:8,train_loss:[1.0210828],eval_loss:0.9956548611323038
epoch:53,batch_id:8,train_loss:[1.0197608],eval_loss:0.9942313432693481
epoch:54,batch_id:8,train_loss:[1.0184289],eval_loss:0.9928037524223328
epoch:55,batch_id:8,train_loss:[1.0170839],eval_loss:0.9913713534673055
epoch:56,batch_id:8,train_loss:[1.0157228],eval_loss:0.9899326960245768
epoch:57,batch_id:8,train_loss:[1.0143429],eval_loss:0.9884872436523438
epoch:58,batch_id:8,train_loss:[1.0129406],eval_loss:0.987034797668457
epoch:59,batch_id:8,train_loss:[1.0115134],eval_loss:0.9855763912200928
epoch:60,batch_id:8,train_loss:[1.0100583],eval_loss:0.9841130574544271
epoch:61,batch_id:8,train_loss:[1.0085737],eval_loss:0.9826481342315674
epoch:62,batch_id:8,train_loss:[1.0070602],eval_loss:0.9811853369077047
epoch:63,batch_id:8,train_loss:[1.005519],eval_loss:0.9797286987304688
epoch:64,batch_id:8,train_loss:[1.0039535],eval_loss:0.978282650311788
epoch:65,batch_id:8,train_loss:[1.002368],eval_loss:0.9768513441085815
epoch:66,batch_id:8,train_loss:[1.0007681],eval_loss:0.9754383365313212
epoch:67,batch_id:8,train_loss:[0.99915963],eval_loss:0.9740457932154337
epoch:68,batch_id:8,train_loss:[0.9975485],eval_loss:0.9726754426956177
epoch:69,batch_id:8,train_loss:[0.9959406],eval_loss:0.9713284373283386
epoch:70,batch_id:8,train_loss:[0.99434185],eval_loss:0.9700064063072205
epoch:71,batch_id:8,train_loss:[0.9927588],eval_loss:0.9687100450197855
epoch:72,batch_id:8,train_loss:[0.9911981],eval_loss:0.9674400091171265
epoch:73,batch_id:8,train_loss:[0.9896648],eval_loss:0.9661963979403178
epoch:74,batch_id:8,train_loss:[0.9881637],eval_loss:0.9649792909622192
epoch:75,batch_id:8,train_loss:[0.9866975],eval_loss:0.963790496190389
epoch:76,batch_id:8,train_loss:[0.9852681],eval_loss:0.9626333912213644
epoch:77,batch_id:8,train_loss:[0.98387516],eval_loss:0.9615116119384766
epoch:78,batch_id:8,train_loss:[0.98251814],eval_loss:0.9604288736979166
epoch:79,batch_id:8,train_loss:[0.9811964],eval_loss:0.9593873818715414
epoch:80,batch_id:8,train_loss:[0.9799079],eval_loss:0.958388884862264
epoch:81,batch_id:8,train_loss:[0.9786506],eval_loss:0.9574349522590637
epoch:82,batch_id:8,train_loss:[0.9774228],eval_loss:0.9565259019533793
epoch:83,batch_id:8,train_loss:[0.97622156],eval_loss:0.9556620121002197
epoch:84,batch_id:8,train_loss:[0.9750451],eval_loss:0.9548425475756327
epoch:85,batch_id:8,train_loss:[0.9738902],eval_loss:0.954066793123881
epoch:86,batch_id:8,train_loss:[0.9727558],eval_loss:0.9533333977063497
epoch:87,batch_id:8,train_loss:[0.9716397],eval_loss:0.9526411096254984
epoch:88,batch_id:8,train_loss:[0.9705405],eval_loss:0.9519882798194885
epoch:89,batch_id:8,train_loss:[0.96945614],eval_loss:0.9513733386993408
epoch:90,batch_id:8,train_loss:[0.96838456],eval_loss:0.9507946173350016
epoch:91,batch_id:8,train_loss:[0.96732265],eval_loss:0.9502503275871277
epoch:92,batch_id:8,train_loss:[0.9662684],eval_loss:0.9497395157814026
epoch:93,batch_id:8,train_loss:[0.9652197],eval_loss:0.9492613275845846
epoch:94,batch_id:8,train_loss:[0.9641763],eval_loss:0.9488150080045065
epoch:95,batch_id:8,train_loss:[0.9631378],eval_loss:0.948401133219401
epoch:96,batch_id:8,train_loss:[0.96210533],eval_loss:0.9480193853378296
epoch:97,batch_id:8,train_loss:[0.9610793],eval_loss:0.9476695458094279
epoch:98,batch_id:8,train_loss:[0.96005946],eval_loss:0.9473506410916647
epoch:99,batch_id:8,train_loss:[0.9590464],eval_loss:0.9470618565877279
Final loss: [0.9590464]
[Figure: training vs. eval loss for the LSTM]
import paddle
# place = fluid.CUDAPlace(0)  # do not use the GPU build on non-develop versions
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
    accs = []
    # model_dict, _ = fluid.load_dygraph('MyLSTMModel')
    model_dict, _ = fluid.load_dygraph('MyModel')   # load by prefix, without the file extension
    model = MyModel()
    # model = MyLSTMModel()
    model.load_dict(model_dict)   # load the saved parameters into the model
    model.eval()                  # switch to evaluation mode for inference
    val_reader = fluid.io.batch(reader=switch_reader(is_val=True), batch_size=batch_size)
    res = []
    for batch_id, eval_data in enumerate(val_reader()):
        eval_data_x = np.array([x[0] for x in eval_data], np.float32)
        eval_data_y = np.array([x[1] for x in eval_data]).astype('float32')
        eval_data_x = fluid.dygraph.to_variable(eval_data_x)
        eval_data_y = fluid.dygraph.to_variable(eval_data_y)
        eval_predict = model(eval_data_x)
        res.append(eval_predict)
res
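The cell above only collects the raw predictions. To back the validation-error comparison from the summary with a number, something like the following sketch could be appended inside the same fluid.dygraph.guard block (my addition; `preds` and `trues` are hypothetical names, and the MAE is in standardized units since the labels were standardized):

    preds, trues = [], []
    for eval_data in val_reader():
        x = fluid.dygraph.to_variable(np.array([d[0] for d in eval_data], np.float32))
        y = np.array([d[1] for d in eval_data]).astype('float32')
        preds.append(model(x).numpy())
        trues.append(y)
    # Mean absolute error over the whole validation set
    mae = np.mean(np.abs(np.concatenate(preds) - np.concatenate(trues)))
    print('validation MAE (standardized units):', mae)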
This concludes the walkthrough of Beijing air-pollution (PM2.5) sequence prediction with an LSTM on PaddlePaddle 2.0.0rc.