这篇文章上次修改于 303 天前,可能其部分内容已经发生变化,如有疑问可询问作者。

模型测试

  • 将上一节训练好并保存了的模型下载到本地项目目录
    请输入图片描述

  • 找一张狗的图片命名为dog2.png , 下载到项目目录下

  • 在项目目录下创建测试文件test.py


import torch import torchvision.transforms from torch import nn from torch.nn import Conv2d, MaxPool2d, Flatten, Linear from PIL import Image class ModelDemo(nn.Module): def __init__(self): super(SeqDemo, self).__init__() self.model = nn.Sequential( Conv2d(3, 32, 5, padding=2), MaxPool2d(2), Conv2d(32, 32, 5, padding=2), MaxPool2d(2), Conv2d(32, 64, 5, padding=2), MaxPool2d(2), Flatten(), Linear(1024,64), Linear(64, 10) ) def forward(self, x): x = self.model(x) return x # 加载模型, 因为本地电脑没有gpu,所以要添加参数map_location=torch.device('cpu') mymodel = torch.load("demoModel40.pth", map_location=torch.device('cpu')) image = Image.open("dog2.png") print(image) # 将图片数据重新修改尺寸是32*32的,然后转换成tensor数据类型 transfrom = torchvision.transforms.Compose([ torchvision.transforms.Resize((32,32)), torchvision.transforms.ToTensor() ]) image = transfrom(image) print(image.shape) # 加入一个batchSize image = torch.reshape(image, (1, 3, 32,32)) print(image.shape) output = mymodel(image) # 枚举10种类型名称 cateData = ['飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车'] print(output) print(output.argmax()) print("图片是{}".format(cateData[output.argmax()]))
  • 看看预测结果
    请输入图片描述