Federated Learning / Machine Learning Notes (2): One-Dimensional Linear Regression + PyTorch Implementation
```python
import os
# Work around the Intel OpenMP "OMP: Error #15" (libiomp5 initialized twice) crash
# that can occur on Windows/Anaconda; set it before torch and matplotlib are imported.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

'''One-dimensional linear regression'''
# Build the dataset: 15 samples, each with one feature x and one target y
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
                    [9.779], [6.182], [7.59], [2.167], [7.042],
                    [10.781], [5.313], [7.997], [3.1]], dtype=np.float32)
y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
                    [3.336], [2.596], [2.53], [1.221], [2.827],
                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)

# Convert the NumPy arrays to tensors
x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)

# Define the model
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(1, 1)  # one input feature, one output

    def forward(self, x):
        out = self.linear(x)
        return out

# Use the GPU if one is available, otherwise the CPU; the model and all of its
# inputs must live on the same device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LinearRegression().to(device)

# Define the loss function and the optimizer
criterion = nn.MSELoss()  # mean squared error
# optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model (Variable is deprecated; plain tensors track gradients directly)
num_epochs = 1500
inputs = x_train.to(device)
target = y_train.to(device)
for epoch in range(num_epochs):
    # Forward pass: model(inputs) runs model.forward(inputs) (plus any registered hooks)
    out = model(inputs)
    loss = criterion(out, target)
    # Backward pass
    optimizer.zero_grad()  # clear the gradients left over from the previous step
    loss.backward()        # compute d(loss)/d(parameters)
    optimizer.step()       # update the weight and bias
    if (epoch + 1) % 20 == 0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch + 1, num_epochs, loss.item()))

# Predict on the training inputs; no gradients are needed here
model.eval()
with torch.no_grad():
    predict = model(x_train.to(device))
predict = predict.cpu().numpy()  # move back to the CPU before converting to NumPy
plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label='Original data')
plt.plot(x_train.numpy(), predict, label='Fitting Line')
plt.legend()
plt.show()
```
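The three calls `optimizer.zero_grad()`, `loss.backward()` and `optimizer.step()` hide the actual arithmetic. For this model the MSE loss is L(w, b) = (1/n) Σ (w·x_i + b − y_i)². Its gradients are ∂L/∂w = (2/n) Σ (w·x_i + b − y_i)·x_i and ∂L/∂b = (2/n) Σ (w·x_i + b − y_i). Plain SGD (no momentum) then updates w ← w − lr·∂L/∂w and b ← b − lr·∂L/∂b. Because the script feeds all 15 samples every epoch, this is effectively full-batch gradient descent.

Below is a minimal hand-written NumPy sketch of the same loop, for illustration only. It assumes a starting point of w = b = 0 (`nn.Linear` actually starts from random values) and computes in float64, so its losses will not match the PyTorch log exactly. The helper name `mse_and_grads` is hypothetical.

```python
import numpy as np

# Same 15 (x, y) pairs as in the PyTorch script, flattened to 1-D
x = np.array([3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167,
              7.042, 10.781, 5.313, 7.997, 3.1], dtype=np.float64)
y = np.array([1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.336, 2.596, 2.53, 1.221,
              2.827, 3.465, 1.65, 2.904, 1.3], dtype=np.float64)

def mse_and_grads(w, b):
    """Hypothetical helper: MSE loss of y ~ w*x + b and its gradients w.r.t. w and b."""
    residual = w * x + b - y            # prediction error for each sample
    loss = np.mean(residual ** 2)       # same as nn.MSELoss() with the default 'mean'
    grad_w = 2.0 * np.mean(residual * x)
    grad_b = 2.0 * np.mean(residual)
    return loss, grad_w, grad_b

# Assumption: start from w = b = 0 instead of PyTorch's random initialization
w, b = 0.0, 0.0
lr = 0.01
num_epochs = 1500
losses = []
for epoch in range(num_epochs):
    loss, grad_w, grad_b = mse_and_grads(w, b)  # forward pass + "loss.backward()"
    w -= lr * grad_w                            # "optimizer.step()" for plain SGD
    b -= lr * grad_b
    losses.append(loss)
    if (epoch + 1) % 300 == 0:
        print('Epoch[{}/{}], loss: {:.6f}'.format(epoch + 1, num_epochs, loss))

print('w = {:.4f}, b = {:.4f}'.format(w, b))
```

There is no `zero_grad()` counterpart here because the gradients are recomputed from scratch each step. In PyTorch, `loss.backward()` accumulates into the existing `.grad` buffers, which is why they have to be cleared first.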
Output:
C:\Users\Administrator\anaconda3\python.exe "D:/paper reading/code/learningcode/oneline.py"
Epoch[20/1500], loss: 0.232388
Epoch[40/1500], loss: 0.226161
Epoch[60/1500], loss: 0.220542
Epoch[80/1500], loss: 0.215471
Epoch[100/1500], loss: 0.210896
Epoch[120/1500], loss: 0.206768
Epoch[140/1500], loss: 0.203043
Epoch[160/1500], loss: 0.199682
Epoch[180/1500], loss: 0.196649
Epoch[200/1500], loss: 0.193912
Epoch[220/1500], loss: 0.191443
Epoch[240/1500], loss: 0.189215
Epoch[260/1500], loss: 0.187204
Epoch[280/1500], loss: 0.185390
Epoch[300/1500], loss: 0.183753
Epoch[320/1500], loss: 0.182276
Epoch[340/1500], loss: 0.180943
Epoch[360/1500], loss: 0.179741
Epoch[380/1500], loss: 0.178656
Epoch[400/1500], loss: 0.177676
Epoch[420/1500], loss: 0.176793
Epoch[440/1500], loss: 0.175996
Epoch[460/1500], loss: 0.175276
Epoch[480/1500], loss: 0.174627
Epoch[500/1500], loss: 0.174041
Epoch[520/1500], loss: 0.173513
Epoch[540/1500], loss: 0.173036
Epoch[560/1500], loss: 0.172606
Epoch[580/1500], loss: 0.172217
Epoch[600/1500], loss: 0.171867
Epoch[620/1500], loss: 0.171551
Epoch[640/1500], loss: 0.171266
Epoch[660/1500], loss: 0.171008
Epoch[680/1500], loss: 0.170776
Epoch[700/1500], loss: 0.170567
Epoch[720/1500], loss: 0.170377
Epoch[740/1500], loss: 0.170207
Epoch[760/1500], loss: 0.170053
Epoch[780/1500], loss: 0.169914
Epoch[800/1500], loss: 0.169789
Epoch[820/1500], loss: 0.169675
Epoch[840/1500], loss: 0.169573
Epoch[860/1500], loss: 0.169481
Epoch[880/1500], loss: 0.169398
Epoch[900/1500], loss: 0.169323
Epoch[920/1500], loss: 0.169256
Epoch[940/1500], loss: 0.169194
Epoch[960/1500], loss: 0.169139
Epoch[980/1500], loss: 0.169090
Epoch[1000/1500], loss: 0.169045
Epoch[1020/1500], loss: 0.169004
Epoch[1040/1500], loss: 0.168968
Epoch[1060/1500], loss: 0.168935
Epoch[1080/1500], loss: 0.168905
Epoch[1100/1500], loss: 0.168878
Epoch[1120/1500], loss: 0.168854
Epoch[1140/1500], loss: 0.168832
Epoch[1160/1500], loss: 0.168813
Epoch[1180/1500], loss: 0.168795
Epoch[1200/1500], loss: 0.168779
Epoch[1220/1500], loss: 0.168764
Epoch[1240/1500], loss: 0.168751
Epoch[1260/1500], loss: 0.168739
Epoch[1280/1500], loss: 0.168729
Epoch[1300/1500], loss: 0.168719
Epoch[1320/1500], loss: 0.168710
Epoch[1340/1500], loss: 0.168703
Epoch[1360/1500], loss: 0.168696
Epoch[1380/1500], loss: 0.168689
Epoch[1400/1500], loss: 0.168684
Epoch[1420/1500], loss: 0.168678
Epoch[1440/1500], loss: 0.168674
Epoch[1460/1500], loss: 0.168669
Epoch[1480/1500], loss: 0.168666
Epoch[1500/1500], loss: 0.168662
Process finished with exit code 0
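The model has only two parameters, so the best possible fit can also be computed in closed form with ordinary least squares:

- w* = Σ(x_i − x̄)(y_i − ȳ) / Σ(x_i − x̄)²
- b* = ȳ − w*·x̄

The following NumPy sketch is an added check, not part of the original script. It computes w*, b* and the minimum achievable MSE on the same data. No choice of w and b can beat that minimum, so the training loss above can only approach it from above. Comparing the last logged loss (0.168662) against it shows how close 1500 epochs at lr = 0.01 actually got.

```python
import numpy as np

# Same 15 (x, y) pairs as in the PyTorch script, flattened to 1-D
x = np.array([3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167,
              7.042, 10.781, 5.313, 7.997, 3.1])
y = np.array([1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.336, 2.596, 2.53, 1.221,
              2.827, 3.465, 1.65, 2.904, 1.3])

# Closed-form least squares for y ~ w*x + b
x_mean, y_mean = x.mean(), y.mean()
w_star = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
b_star = y_mean - w_star * x_mean

# Smallest MSE any straight line can reach on this data
min_mse = np.mean((w_star * x + b_star - y) ** 2)
print('w* = {:.4f}, b* = {:.4f}, minimum MSE = {:.6f}'.format(w_star, b_star, min_mse))
```

`np.polyfit(x, y, 1)` returns the same slope and intercept and is a shorter alternative. The explicit formulas just make the connection to the gradient-based training visible.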