这篇文章上次修改于 303 天前,可能其部分内容已经发生变化,如有疑问可询问作者。
模型测试
将上一节训练好并保存了的模型下载到本地项目目录
找一张狗的图片命名为dog2.png , 下载到项目目录下
在项目目录下创建测试文件test.py
import torch
import torchvision.transforms
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
from PIL import Image
class ModelDemo(nn.Module):
def __init__(self):
super(SeqDemo, self).__init__()
self.model = nn.Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024,64),
Linear(64, 10)
)
def forward(self, x):
x = self.model(x)
return x
# 加载模型, 因为本地电脑没有gpu,所以要添加参数map_location=torch.device('cpu')
mymodel = torch.load("demoModel40.pth", map_location=torch.device('cpu'))
image = Image.open("dog2.png")
print(image)
# 将图片数据重新修改尺寸是32*32的,然后转换成tensor数据类型
transfrom = torchvision.transforms.Compose([
torchvision.transforms.Resize((32,32)),
torchvision.transforms.ToTensor()
])
image = transfrom(image)
print(image.shape)
# 加入一个batchSize
image = torch.reshape(image, (1, 3, 32,32))
print(image.shape)
output = mymodel(image)
# 枚举10种类型名称
cateData = ['飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车']
print(output)
print(output.argmax())
print("图片是{}".format(cateData[output.argmax()]))
- 看看预测结果
没有评论
博主关闭了评论...