正弦函数预测

通过已知的样本数据对正弦函数进行预测,并绘制出图形。分别设计 LSTM GRU和 RNN网络进行预测。

导包

1
2
3
4
5
6
import numpy as np
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt

数据集生成与划分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
data = []
start = 0
for i in range(200):
x = [np.sin(x/10) for x in range(start, start+11)]
data.append(x)
start = start + 1

data = np.array(data)
data = torch.from_numpy(data)

target = data[:, -1:].type(torch.FloatTensor)
data = data[:, :-1].type(torch.FloatTensor)

train_x = data[:150]
train_y = target[:150]
test_x = data[150:]
test_y = target[150:]

train_dataset = TensorDataset(train_x, train_y)
test_dataset = TensorDataset(test_x, test_y)
train_loader = DataLoader(dataset=train_dataset, batch_size=5, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=5, shuffle=False)

网络构建

  1. nn.RNN(1, 10, batch_first=True): RNN是最基本的循环神经网络结构。它的每个时间步接收输入和前一个时间步的隐藏状态,并输出一个新的隐藏状态。输入维度为1,输出维度为10。batch_first=True表示输入数据的形状为(batch_size, sequence_length, input_size)。
  2. nn.LSTM(1, 10, batch_first=True): 为了解决RNN的长期依赖问题,LSTM引入了一个记忆单元(cell state)和三个门:遗忘门(forget gate)、输入门(input gate)和输出门(output gate)。
  3. nn.GRU(1, 10, batch_first=True): GRU是一种介于RNN和LSTM之间的结构,相较于LSTM更简单。它合并了遗忘门和输入门,仅保留了一个更新门(update gate)和一个重置门(reset gate)。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Model(nn.Module):
def __init__(self):
super().__init__()
self.rnn = nn.RNN(1, 10, batch_first=True)
self.lstm = nn.LSTM(1, 10, batch_first=True)
self.gru = nn.GRU(1, 10, batch_first=True)
self.fc1 = nn.Linear(10, 1)

def forward(self, x, hidden):
# output, hidden = self.rnn(x, hidden)
output, hidden = self.lstm(x, hidden)
# output, hidden = self.gru(x, hidden)
output = self.fc1(output[:, -1, :])
return output

训练

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
net = Model()
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(net.parameters(), lr=0.001)

h0 = torch.zeros(1, 5, 10)
c0 = torch.zeros(1, 5, 10)

for epoch in range(500):
for i, data in enumerate(train_loader):
x, y = data
x = x.view(-1, 10, 1)

pred = net(x, (h0, c0))
loss = loss_fn(pred, y)

opt.zero_grad()
loss.backward()
opt.step()

if epoch % 50 == 0:
print('epoch: {}, loss: {:.5f}'.format(epoch, loss.item()))

测试

1
2
3
4
5
6
7
8
9
10
11
12
13
preds = []
for i, data in enumerate(test_loader):
x, y = data
x = x.view(-1, 10, 1)
hidden = torch.zeros(1, 5, 10)
pred = net(x, (h0, c0))
preds.append(pred.data.numpy())
print(y.view(1, -1).data)
print(pred.view(1, -1).data, '\n')

plt.scatter(range(len(train_y)), train_y.data.numpy(), marker='o')
plt.scatter(range(150, 200), preds, marker='s')
plt.show()98191

输出:

RNN

训练期间损失输出:

1
2
3
4
5
6
7
8
9
10
epoch: 0, loss: 0.57561
epoch: 50, loss: 0.00049
epoch: 100, loss: 0.00011
epoch: 150, loss: 0.00001
epoch: 200, loss: 0.00004
epoch: 250, loss: 0.00001
epoch: 300, loss: 0.00000
epoch: 350, loss: 0.00001
epoch: 400, loss: 0.00000
epoch: 450, loss: 0.00000

测试结果对比:

tensor([[-0.2879, -0.3821, -0.4724, -0.5581, -0.6381]])
tensor([[-0.2888, -0.3836, -0.4745, -0.5604, -0.6403]]) 

tensor([[-0.7118, -0.7784, -0.8371, -0.8876, -0.9291]])
tensor([[-0.7133, -0.7791, -0.8372, -0.8873, -0.9288]]) 

tensor([[-0.9614, -0.9841, -0.9969, -0.9998, -0.9927]])
tensor([[-0.9611, -0.9838, -0.9966, -0.9994, -0.9925]]) 

tensor([[-0.9756, -0.9488, -0.9126, -0.8672, -0.8132]])
tensor([[-0.9758, -0.9496, -0.9139, -0.8689, -0.8150]]) 

tensor([[-0.7510, -0.6813, -0.6048, -0.5223, -0.4346]])
tensor([[-0.7527, -0.6825, -0.6053, -0.5220, -0.4338]]) 

tensor([[-0.3425, -0.2470, -0.1490, -0.0495,  0.0504]])
tensor([[-0.3415, -0.2462, -0.1488, -0.0499,  0.0496]]) 

tensor([[0.1499, 0.2478, 0.3433, 0.4354, 0.5231]])
tensor([[0.1489, 0.2470, 0.3428, 0.4352, 0.5231]]) 

tensor([[0.6055, 0.6820, 0.7516, 0.8137, 0.8676]])
tensor([[0.6056, 0.6818, 0.7510, 0.8128, 0.8665]]) 

tensor([[0.9129, 0.9491, 0.9758, 0.9928, 0.9998]])
tensor([[0.9117, 0.9476, 0.9740, 0.9907, 0.9975]]) 

tensor([[0.9968, 0.9839, 0.9612, 0.9288, 0.8872]])
tensor([[0.9946, 0.9818, 0.9595, 0.9276, 0.8862]]) 

GRU

训练期间损失输出:

1
2
3
4
5
6
7
8
9
10
epoch: 0, loss: 0.44557
epoch: 50, loss: 0.00044
epoch: 100, loss: 0.00007
epoch: 150, loss: 0.00003
epoch: 200, loss: 0.00001
epoch: 250, loss: 0.00001
epoch: 300, loss: 0.00000
epoch: 350, loss: 0.00003
epoch: 400, loss: 0.00001
epoch: 450, loss: 0.00000

测试结果对比:

tensor([[-0.2879, -0.3821, -0.4724, -0.5581, -0.6381]])
tensor([[-0.2867, -0.3809, -0.4715, -0.5576, -0.6382]]) 

tensor([[-0.7118, -0.7784, -0.8371, -0.8876, -0.9291]])
tensor([[-0.7124, -0.7792, -0.8379, -0.8880, -0.9290]]) 

tensor([[-0.9614, -0.9841, -0.9969, -0.9998, -0.9927]])
tensor([[-0.9606, -0.9826, -0.9949, -0.9976, -0.9906]]) 

tensor([[-0.9756, -0.9488, -0.9126, -0.8672, -0.8132]])
tensor([[-0.9739, -0.9477, -0.9121, -0.8673, -0.8137]]) 

tensor([[-0.7510, -0.6813, -0.6048, -0.5223, -0.4346]])
tensor([[-0.7516, -0.6818, -0.6048, -0.5217, -0.4333]]) 

tensor([[-0.3425, -0.2470, -0.1490, -0.0495,  0.0504]])
tensor([[-0.3407, -0.2450, -0.1471, -0.0481,  0.0513]]) 

tensor([[0.1499, 0.2478, 0.3433, 0.4354, 0.5231]])
tensor([[0.1501, 0.2475, 0.3427, 0.4348, 0.5228]]) 

tensor([[0.6055, 0.6820, 0.7516, 0.8137, 0.8676]])
tensor([[0.6058, 0.6829, 0.7530, 0.8153, 0.8692]]) 

tensor([[0.9129, 0.9491, 0.9758, 0.9928, 0.9998]])
tensor([[0.9142, 0.9499, 0.9760, 0.9924, 0.9991]]) 

tensor([[0.9968, 0.9839, 0.9612, 0.9288, 0.8872]])
tensor([[0.9960, 0.9832, 0.9608, 0.9289, 0.8877]]) 

LSTM

训练期间损失输出:

1
2
3
4
5
6
7
8
9
10
epoch: 0, loss: 0.65284
epoch: 50, loss: 0.00016
epoch: 100, loss: 0.00011
epoch: 150, loss: 0.00001
epoch: 200, loss: 0.00000
epoch: 250, loss: 0.00000
epoch: 300, loss: 0.00000
epoch: 350, loss: 0.00000
epoch: 400, loss: 0.00001
epoch: 450, loss: 0.00004

测试结果对比:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
tensor([[-0.2879, -0.3821, -0.4724, -0.5581, -0.6381]])
tensor([[-0.2881, -0.3823, -0.4726, -0.5581, -0.6380]])

tensor([[-0.7118, -0.7784, -0.8371, -0.8876, -0.9291]])
tensor([[-0.7115, -0.7779, -0.8367, -0.8874, -0.9293]])

tensor([[-0.9614, -0.9841, -0.9969, -0.9998, -0.9927]])
tensor([[-0.9620, -0.9850, -0.9978, -1.0005, -0.9930]])

tensor([[-0.9756, -0.9488, -0.9126, -0.8672, -0.8132]])
tensor([[-0.9759, -0.9493, -0.9137, -0.8691, -0.8157]])

tensor([[-0.7510, -0.6813, -0.6048, -0.5223, -0.4346]])
tensor([[-0.7539, -0.6841, -0.6072, -0.5242, -0.4361]])

tensor([[-0.3425, -0.2470, -0.1490, -0.0495, 0.0504]])
tensor([[-0.3441, -0.2489, -0.1514, -0.0523, 0.0475]])

tensor([[0.1499, 0.2478, 0.3433, 0.4354, 0.5231]])
tensor([[0.1470, 0.2451, 0.3408, 0.4329, 0.5204]])

tensor([[0.6055, 0.6820, 0.7516, 0.8137, 0.8676]])
tensor([[0.6025, 0.6784, 0.7476, 0.8094, 0.8634]])

tensor([[0.9129, 0.9491, 0.9758, 0.9928, 0.9998]])
tensor([[0.9090, 0.9456, 0.9726, 0.9896, 0.9963]])

tensor([[0.9968, 0.9839, 0.9612, 0.9288, 0.8872]])
tensor([[0.9929, 0.9797, 0.9569, 0.9249, 0.8838]])