Learning the Keras Library: Time-Series Analysis with an LSTM Model
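1. Problem statement
Use an LSTM model to forecast the daily gold benchmark price set in the London market's afternoon session (the LBMA Gold Price PM, in US dollars).
Data source: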
https://fsapi-china.gold.org/api/goldprice/v11/chart/main?period=Max
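The script in section 2 reads the `chartData` field of this response. Judging from how it processes it, each series (such as `lbma_pm_usd`) is a list of `[timestamp in milliseconds, price]` pairs. The following is a minimal sketch under that assumption: it converts one series to a date-indexed pandas Series with `pd.to_datetime(..., unit='ms', utc=True)` instead of `time.localtime`, so the resulting dates do not depend on the machine's time zone. The payload is hypothetical. The two prices are the last two rows printed in section 3, and the timestamps are assumed to be UTC midnight of those dates. `series_from_chart_data` is an illustrative helper name.

```python
import pandas as pd

# Hypothetical payload shaped like response.json()['chartData'] as the script uses it:
# each series name maps to a list of [epoch_milliseconds, price] pairs.
# ASSUMPTION: timestamps are UTC midnight (2025-07-14 and 2025-07-15 here).
chart_data = {
    "lbma_pm_usd": [
        [1752451200000, 3351.15],
        [1752537600000, 3345.10],
    ],
}

def series_from_chart_data(chart_data, key):
    """Convert one [ms, price] list into a date-indexed pandas Series named 'Truth'."""
    frame = pd.DataFrame(chart_data[key], columns=["Timestamp", "Truth"])
    # Interpret the timestamps as UTC so the dates do not depend on the local time zone.
    frame["Date"] = pd.to_datetime(frame["Timestamp"], unit="ms", utc=True).dt.strftime("%Y-%m-%d")
    return frame.set_index("Date")["Truth"]

prices = series_from_chart_data(chart_data, "lbma_pm_usd")
print(prices)
```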
2. Python code
import os
# Silence TensorFlow's INFO/WARNING logs; this must be set before Keras/TensorFlow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import time
import warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
from keras.callbacks import EarlyStopping
from keras.layers import LSTM, Dense, Dropout
from keras.models import Sequential
from sklearn.metrics import explained_variance_score, mean_absolute_error, mean_squared_error, r2_score
from sklearn.preprocessing import MinMaxScaler
warnings.filterwarnings("ignore")
# SimSun can render CJK text; swap in any installed font if it is unavailable.
plt.rcParams['font.sans-serif'] = ['SimSun']
# Use an ASCII hyphen for minus signs so tick labels render with fonts lacking U+2212.
plt.rcParams['axes.unicode_minus'] = False
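# Download the full price history from the gold price chart API.
url = 'https://fsapi-china.gold.org/api/goldprice/v11/chart/main?period=Max'
response = requests.get(url, timeout=30)
response.raise_for_status()
chart_data = response.json()['chartData']
# Keep the first eight series (the LBMA AM/PM and SGE AM/PM benchmarks printed in section 3).
gold_price_keys = list(chart_data.keys())[:8]
gold_prices = {}
for gold_price_key in gold_price_keys:
    gold_price_values = chart_data[gold_price_key]
    for sublist in gold_price_values:
        # Each point is [timestamp in ms, price]; convert the timestamp to a date string
        # (time.localtime uses this machine's local time zone).
        sublist[0] = time.strftime('%Y-%m-%d', time.localtime(sublist[0] / 1000))
    gold_prices[gold_price_key] = gold_price_values
print('Series covered by the data set:\n', list(gold_prices.keys()))
data = pd.DataFrame(gold_prices['lbma_pm_usd'], columns=['Date', 'Truth'])
data = data.set_index('Date')
print('\nDaily gold benchmark price, London afternoon session (USD):\n', data)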
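# Scale prices to [0, 1]. Note: the scaler is fitted on the whole series, test period included.
scaler = MinMaxScaler(feature_range=(0, 1))
data_std = scaler.fit_transform(data.values)
# Build supervised samples: the previous 5 values are used to predict the next one.
moving_win_size = 5
all_x, all_y = [], []
for i in range(len(data_std) - moving_win_size):
    x = data_std[i:i + moving_win_size]   # input window, shape (5, 1)
    y = data_std[i + moving_win_size]     # next value, shape (1,)
    all_x.append(x)
    all_y.append(y)
all_x, all_y = np.array(all_x), np.array(all_y)
# Chronological 80/20 split of the windows (earlier windows train, later windows test).
train_ds_size = round(all_x.shape[0] * 0.8)
train_x, train_y = all_x[:train_ds_size], all_y[:train_ds_size]
test_x, test_y = all_x[train_ds_size:], all_y[train_ds_size:]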
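print('\nLSTM model training:')
# Two stacked LSTM layers with dropout, followed by a small dense head.
model = Sequential()
model.add(LSTM(units=64, return_sequences=True, input_shape=(train_x.shape[1], 1)))
model.add(Dropout(0.2))
model.add(LSTM(units=32))
model.add(Dropout(0.2))
model.add(Dense(units=16, activation='relu'))
model.add(Dense(units=1))
model.compile(optimizer='adam', loss='mse', metrics=['mape'])
# Stop once val_loss has not improved for 10 epochs and roll back to the best weights.
callback = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
# validation_split=0.2 holds out the last 20% of the training windows for validation.
model.fit(train_x, train_y, batch_size=32, validation_split=0.2, callbacks=[callback], epochs=500)
# summary() prints the table itself and returns None.
model.summary()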
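# Predict the test windows and map the predictions back to USD.
pred = scaler.inverse_transform(model.predict(test_x))
# Row-wise split of the original prices that matches the window split above.
train = data.iloc[:train_ds_size + moving_win_size]
test = data.iloc[train_ds_size + moving_win_size:]
test = test.assign(Predict=pred.ravel())
fig, ax = plt.subplots(1, 1)
ax.plot(train['Truth'], 'g')
ax.plot(test['Truth'], 'r')
ax.plot(test['Predict'], 'b')
ax.legend(['Training data', 'Test data', 'Predictions'])
ax.set_title('LSTM forecast of the gold benchmark price (USD)')
fig.autofmt_xdate(rotation=45)
# The dates are strings, so matplotlib treats them as categories; label every 30th one.
ax.xaxis.set_major_locator(plt.MultipleLocator(30))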
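# Naive baseline: use the previous observed price as the forecast.
test = test.assign(Shifted=test['Truth'].shift(1))
# The first test row has no predecessor inside `test`, so take the last training price.
test.iat[0, -1] = train.iat[-1, -1]
print('\nTest data | predictions | previous value:\n', test)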
# RMSE is computed as sqrt(MSE): the `squared` argument of mean_squared_error
# was deprecated in scikit-learn 1.4 and removed in 1.6.
metrics = {
    'MSE': mean_squared_error,
    'RMSE': lambda y_true, y_pred: np.sqrt(mean_squared_error(y_true, y_pred)),
    'MAE': mean_absolute_error,
    'EV': explained_variance_score,
    'R2': r2_score
}
dataframe = pd.DataFrame(index=list(metrics.keys()), columns=['Predict', 'Shifted'])
for metric, func in metrics.items():
    dataframe.at[metric, 'Predict'] = func(test['Truth'], test['Predict'])
    dataframe.at[metric, 'Shifted'] = func(test['Truth'], test['Shifted'])
print('\nModel evaluation:\n', dataframe)
plt.show()
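The script fits `MinMaxScaler` on the entire series before splitting, so the scaling of the training data already depends on the test period's prices (the 2025 highs set the top of the [0, 1] range). Below is a minimal sketch of a leakage-free variant. It keeps the same window length of 5 and the same 80/20 chronological split of windows, but fits the scaler only on the rows that the training windows and targets cover. It also replaces the Python loop with NumPy's `sliding_window_view`, which produces the same windows. The helper names `make_windows` and `split_and_scale` and the synthetic linear price series are illustrative assumptions, not part of the original script.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view
from sklearn.preprocessing import MinMaxScaler

def make_windows(series, win):
    """x[i] = series[i:i+win] with shape (n, win, 1); y[i] = series[i+win] with shape (n, 1)."""
    series = np.asarray(series, dtype=float).reshape(-1)
    # The last window has no following value to predict, so drop it.
    x = sliding_window_view(series, win)[:-1]
    y = series[win:]
    return x[..., np.newaxis], y[:, np.newaxis]

def split_and_scale(prices, win=5, train_frac=0.8):
    """Chronological split of windows; the scaler only sees rows used by training windows."""
    prices = np.asarray(prices, dtype=float).reshape(-1, 1)
    n_windows = len(prices) - win
    train_size = round(n_windows * train_frac)
    scaler = MinMaxScaler(feature_range=(0, 1))
    # Training windows and their targets cover rows 0 .. train_size + win - 1.
    scaler.fit(prices[:train_size + win])
    scaled = scaler.transform(prices)
    x, y = make_windows(scaled, win)
    return (x[:train_size], y[:train_size]), (x[train_size:], y[train_size:]), scaler

# Synthetic stand-in for data['Truth'], with the same length as the real series (498 rows).
prices = np.linspace(35.0, 3345.0, 498)
(train_x, train_y), (test_x, test_y), scaler = split_and_scale(prices)
print(train_x.shape, train_y.shape, test_x.shape, test_y.shape)
```

With the real gold series, the later test values would then exceed 1, since 2025 prices above 3,000 USD are far higher than anything before mid-2014.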
3. Output
Series covered by the data set:
['lbma_am_usd', 'lbma_am_gbp', 'lbma_am_eur', 'lbma_pm_usd', 'lbma_pm_gbp', 'lbma_pm_eur', 'sge_am_cny', 'sge_pm_cny']
Daily gold benchmark price, London afternoon session (USD):
Truth
Date
1970-02-11 35.03
1970-03-24 35.17
1970-05-04 35.90
1970-06-12 35.64
1970-07-23 35.30
... ...
2025-04-03 3118.10
2025-05-14 3191.95
2025-06-24 3302.50
2025-07-14 3351.15
2025-07-15 3345.10
[498 rows x 1 columns]
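The LBMA PM price is published every trading day, but the printed index suggests that this `period=Max` response is much sparser. The 498 rows span 1970 to 2025, and the rows shown are mostly about 40 days apart. As a result, each 5-row input window spans several months rather than a trading week, and the previous row is usually more than a month old. The check below runs on the dates printed above. On the full frame, `pd.to_datetime(data.index).to_series().diff().dt.days.describe()` would summarize all gaps.

```python
import pandas as pd

# Dates copied from the printed head and tail of the data frame above.
dates = pd.to_datetime([
    "1970-02-11", "1970-03-24", "1970-05-04", "1970-06-12", "1970-07-23",
    "2025-04-03", "2025-05-14", "2025-06-24", "2025-07-14", "2025-07-15",
])
# Gap in days between consecutive rows, separately for the head and the tail.
head_gaps = pd.Series(dates[:5]).diff().dt.days.dropna()
tail_gaps = pd.Series(dates[5:]).diff().dt.days.dropna()
print(head_gaps.tolist())
print(tail_gaps.tolist())
```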
LSTM model training:
Epoch 1/500
10/10 [==============================] - 2s 49ms/step - loss: 0.0038 - mape: 206.7163 - val_loss: 0.0457 - val_mape: 54.8422
Epoch 2/500
10/10 [==============================] - 0s 5ms/step - loss: 0.0015 - mape: 417.1404 - val_loss: 0.0426 - val_mape: 53.1359
Epoch 3/500
10/10 [==============================] - 0s 5ms/step - loss: 0.0011 - mape: 303.3871 - val_loss: 0.0482 - val_mape: 58.7283
Epoch 4/500
10/10 [==============================] - 0s 5ms/step - loss: 9.5144e-04 - mape: 281.8955 - val_loss: 0.0326 - val_mape: 46.8034
Epoch 5/500
10/10 [==============================] - 0s 5ms/step - loss: 7.7782e-04 - mape: 284.0965 - val_loss: 0.0276 - val_mape: 44.3014
Epoch 6/500
10/10 [==============================] - 0s 6ms/step - loss: 5.7018e-04 - mape: 217.6779 - val_loss: 0.0132 - val_mape: 30.4231
Epoch 7/500
10/10 [==============================] - 0s 6ms/step - loss: 3.2493e-04 - mape: 144.3830 - val_loss: 0.0048 - val_mape: 18.7642
Epoch 8/500
10/10 [==============================] - 0s 5ms/step - loss: 2.3188e-04 - mape: 84.9539 - val_loss: 0.0013 - val_mape: 8.9728
Epoch 9/500
10/10 [==============================] - 0s 5ms/step - loss: 2.3613e-04 - mape: 70.6797 - val_loss: 0.0013 - val_mape: 8.4729
Epoch 10/500
10/10 [==============================] - 0s 5ms/step - loss: 2.3620e-04 - mape: 63.4833 - val_loss: 0.0013 - val_mape: 9.3378
Epoch 11/500
10/10 [==============================] - 0s 6ms/step - loss: 2.1308e-04 - mape: 64.9971 - val_loss: 0.0016 - val_mape: 10.5365
Epoch 12/500
10/10 [==============================] - 0s 5ms/step - loss: 2.3602e-04 - mape: 72.1374 - val_loss: 0.0015 - val_mape: 9.6465
Epoch 13/500
10/10 [==============================] - 0s 5ms/step - loss: 2.1160e-04 - mape: 59.4635 - val_loss: 0.0020 - val_mape: 11.8361
Epoch 14/500
10/10 [==============================] - 0s 4ms/step - loss: 2.3469e-04 - mape: 53.3759 - val_loss: 0.0019 - val_mape: 11.3738
Epoch 15/500
10/10 [==============================] - 0s 5ms/step - loss: 2.1404e-04 - mape: 48.5358 - val_loss: 0.0020 - val_mape: 11.5357
Epoch 16/500
10/10 [==============================] - 0s 6ms/step - loss: 2.1892e-04 - mape: 39.3191 - val_loss: 0.0028 - val_mape: 13.9195
Epoch 17/500
10/10 [==============================] - 0s 5ms/step - loss: 2.1128e-04 - mape: 41.0462 - val_loss: 0.0021 - val_mape: 11.4239
Epoch 18/500
10/10 [==============================] - 0s 5ms/step - loss: 2.0450e-04 - mape: 31.5741 - val_loss: 0.0024 - val_mape: 12.6676
Epoch 19/500
10/10 [==============================] - 0s 5ms/step - loss: 2.0496e-04 - mape: 36.7726 - val_loss: 0.0035 - val_mape: 15.3388
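The training `mape` values in the hundreds look alarming next to the small `loss`, but both are computed on the min-max-scaled targets. The earliest prices (about 35 USD) scale to values close to 0, and a percentage error divided by a near-zero target explodes. The validation windows come from later, higher prices, so `val_mape` stays far lower. The illustration below uses a `mape` helper that, to my understanding, mirrors Keras's formula (absolute error divided by the true value clipped below at 1e-7, times 100). The numbers are made up.

```python
import numpy as np

def mape(y_true, y_pred, eps=1e-7):
    """MAPE in percent; |y_true| is clipped below at eps (assumed to match Keras's definition)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return 100.0 * np.mean(np.abs(y_true - y_pred) / np.maximum(np.abs(y_true), eps))

# Made-up scaled targets: one near the bottom of [0, 1] (an early, low price) and one near the top.
y_true = np.array([0.002, 0.9])
# Both predictions miss by the same absolute amount, 0.01 in scaled units.
y_pred = y_true + 0.01
low = mape(y_true[:1], y_pred[:1])
high = mape(y_true[1:], y_pred[1:])
print(low, high)
```

The same 0.01 miss costs about 500% at the low target and about 1.1% at the high one.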
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm (LSTM) (None, 5, 64) 16896
dropout (Dropout) (None, 5, 64) 0
lstm_1 (LSTM) (None, 32) 12416
dropout_1 (Dropout) (None, 32) 0
dense (Dense) (None, 16) 528
dense_1 (Dense) (None, 1) 17
=================================================================
Total params: 29857 (116.63 KB)
Trainable params: 29857 (116.63 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
None
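The parameter counts in this summary can be reproduced by hand. A Keras LSTM layer has four gates, each with an input kernel, a recurrent kernel and a bias, so it holds 4 × (units × input_dim + units × units + units) weights. A Dense layer holds units × input_dim + units. The helper names below are illustrative, and the 116.63 KB figure corresponds to 4-byte float32 weights.

```python
def lstm_params(units, input_dim):
    """Standard LSTM layer: 4 gates x (input kernel + recurrent kernel + bias)."""
    return 4 * (units * input_dim + units * units + units)

def dense_params(units, input_dim):
    """Dense layer: kernel plus bias."""
    return units * input_dim + units

layers = {
    "lstm": lstm_params(64, 1),      # one feature (the price) per time step
    "lstm_1": lstm_params(32, 64),   # consumes the 64-dim sequence from the first LSTM
    "dense": dense_params(16, 32),
    "dense_1": dense_params(1, 16),
}
total = sum(layers.values())
print(layers, total)
```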
4/4 [==============================] - 0s 1ms/step
Test data | predictions | previous value:
Truth Predict Shifted
Date
2014-07-30 1294.50 1310.929688 1293.00
2014-09-09 1255.75 1336.798340 1294.50
2014-10-20 1244.50 1324.560669 1255.75
2014-11-28 1182.75 1318.717529 1244.50
2015-01-08 1215.50 1305.177368 1182.75
... ... ... ...
2025-04-03 3118.10 2724.807129 2934.15
2025-05-14 3191.95 2812.505859 3118.10
2025-06-24 3302.50 2869.609619 3191.95
2025-07-14 3351.15 2996.434814 3302.50
2025-07-15 3345.10 3152.391602 3351.15
[99 rows x 3 columns]
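The `Shifted` column is a naive persistence forecast: each test price is "predicted" by the price one row earlier, and the script patches the first test row with the last training price via `.iat`. Shifting the whole series before slicing gives the same column without the patch. On the real frame the equivalent would be `data['Truth'].shift(1).iloc[train_ds_size + moving_win_size:]`. The sketch below runs on the first rows printed above. The label `last_train_row` is a hypothetical stand-in for the date of the final training row, which is not shown.

```python
import pandas as pd

# Last training price (1293.00, as implied by the first Shifted value) followed by the
# first five test prices printed above. 'last_train_row' is a hypothetical label.
truth = pd.Series(
    [1293.00, 1294.50, 1255.75, 1244.50, 1182.75, 1215.50],
    index=["last_train_row", "2014-07-30", "2014-09-09", "2014-10-20", "2014-11-28", "2015-01-08"],
)
# Shift first, then drop the training row: the first test row automatically
# receives the last training value, so no separate fix-up is needed.
shifted = truth.shift(1).iloc[1:]
print(shifted.tolist())
```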
Model evaluation:
Predict Shifted
MSE 20435.330881 5904.579394
RMSE 142.952198 76.841261
MAE 110.389102 56.673737
EV 0.935066 0.981972
R2 0.932708 0.980557
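Two notes on this table. First, `mean_squared_error(..., squared=False)`, used for RMSE in the script, was deprecated in scikit-learn 1.4 (in favour of `root_mean_squared_error`) and removed in 1.6. Taking `np.sqrt` of the MSE works across versions, and the table is consistent with it (√20435.33 ≈ 142.95 and √5904.58 ≈ 76.84). Second, EV is never below R². The two differ exactly by mean(residual)² / var(y_true), meaning EV ignores a constant offset in the errors. For the LSTM column the gap is only about 0.0024. By this identity, that corresponds to an average error of roughly 27 USD in magnitude, so a constant bias explains little of its 143 USD RMSE. The sketch below verifies the identity on synthetic data; the `evaluate` helper and the data are illustrative.

```python
import numpy as np
from sklearn.metrics import explained_variance_score, mean_absolute_error, mean_squared_error, r2_score

def evaluate(y_true, y_pred):
    """The script's metrics, with RMSE computed via np.sqrt for scikit-learn version independence."""
    mse = mean_squared_error(y_true, y_pred)
    return {
        "MSE": mse,
        "RMSE": float(np.sqrt(mse)),
        "MAE": mean_absolute_error(y_true, y_pred),
        "EV": explained_variance_score(y_true, y_pred),
        "R2": r2_score(y_true, y_pred),
    }

rng = np.random.default_rng(0)
y_true = np.linspace(1200.0, 3300.0, 99)              # synthetic stand-in for test['Truth']
y_pred = y_true - 25.0 + rng.normal(0.0, 50.0, 99)    # a biased, noisy forecast
m = evaluate(y_true, y_pred)

# EV - R2 = mean(residual)**2 / var(y_true), using population variances.
residual = y_true - y_pred
gap = residual.mean() ** 2 / y_true.var()
print(m["EV"] - m["R2"], gap)
```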
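Judging from the evaluation results above, the LSTM model's forecasts are worse than simply using the previous observation as the prediction: its MSE, RMSE and MAE are higher and its EV and R² are lower. Further improvement of the model will require giving more weight to the most recent dates.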
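One possible way to act on this, offered as an untested sketch and not as the author's method, is to pass per-sample weights to `model.fit` through its `sample_weight` argument. For example, the weights could decay exponentially with the age of each training window, so that recent windows count more in the loss. The `recency_weights` helper and the half-life of 40 windows are arbitrary assumptions.

```python
import numpy as np

def recency_weights(n_samples, half_life):
    """Exponential-decay weights: the newest sample gets the largest weight,
    a sample `half_life` steps older gets half of it."""
    age = np.arange(n_samples)[::-1]          # 0 for the most recent window
    weights = 0.5 ** (age / half_life)
    # Rescale so the average weight is 1, keeping the loss on a comparable scale.
    return weights / weights.mean()

w = recency_weights(394, half_life=40)   # 394 = number of training windows in the script
print(w[:3], w[-3:])

# Hypothetical use with the script's model and arrays:
# model.fit(train_x, train_y, sample_weight=w, batch_size=32, validation_split=0.2,
#           callbacks=[callback], epochs=500)
```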

Last updated on 2025-06-30