Pytorch 模型加载保存预测整理
最近在学习pytorch,把踩过的坑,整理一下。##### 1、预训练模型的加载问题在模型加载过程中,常用的有两种方式:###### (1)直接保存加载训好的模型```pythontorch.save(model, 'src/model.pth')# 直接保存模型model = torch.load('src/model.pth')# 直接加载模型```###### (2)只加载模型参数,进行训练
·
最近在学习pytorch,把踩过的坑,整理一下。
1、预训练模型的加载问题
在模型加载过程中,常用的有两种方式:
(1)直接保存加载训好的模型
torch.save(model, 'src/model.pth') # 直接保存模型
model = torch.load('src/model.pth') # 直接加载模型
(2)只加载模型参数,进行训练
model = ... # 训练好的模型
torch.save(model.state_dict(), 'src/params.pth') # 只保存模型参数
# -----------------------------------------------
model = ... # 新的模型(没训练)
model.load_state_dict(torch.load('src/params.pth')) # 加载模型参数,并更新到模型中
(3)模型保存的问题
- A、如果只保存模型参数,其实就把参数用字典的形式保存下来,我们如果只需要其中的几层或者仅只修改后面一层,我们只需把要修改的层替换掉,即可。
#把不属于新模型 model_dict 的参数, 除掉。pretrained_dict是预训练参数
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
- B、如果在使用多卡训练模型时,即使用
model = torch.nn.DataParallel(model, device_ids=device_ids)
,在保存模型参数时,建议使用torch.save(model.module.state_dict(), 'src/params.pth')
,否则,保存的模型参数中的key
的名字会多出module
,具体可根据程序进行修改。
2、用训练好的模型预测代码实现
#!/usr/bin/python
# -*- coding: UTF-8 -*-
import torchvision as tv
import torchvision.transforms as transforms
import torch as t
from PIL import Image
def pridict():
device = t.device("cuda" if t.cuda.is_available() else "cpu")
model = tv.models.resnet18(pretrained=True) # 创建一个模型
model = model.to(device)
model.eval() # 预测模式
# 获取测试图片,并行相应的处理
img = Image.open('cat.jpg')
transform = transforms.Compose([transforms.Resize(256), # 重置图像分辨率
transforms.CenterCrop(224), # 中心裁剪
transforms.ToTensor(), ])
img = img.convert("RGB") # 如果是标准的RGB格式,则可以不加
img = transform(img)
img = img.unsqueeze(0)
img = img.to(device)
with t.no_grad():
py = model(img)
_, predicted = t.max(py, 1) # 获取分类结果
classIndex_ = predicted[0]
print('预测结果', classIndex_)
if __name__ == '__main__':
pridict()
声明: 总结学习,有问题或不当之处,可以批评指正哦,谢谢。
参考链接
更多推荐
已为社区贡献1条内容
所有评论(0)