import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms
from torchvision import datasets
from torchvision import models
数据集处理
将图片按类别分别移动到对应的文件夹（dog、cat）中，以满足 ImageFolder 按子文件夹读取数据的目录结构要求：
# Run inside the directory that holds the dog.* and cat.* training images.
# The target directories must exist first: with several source files,
# `mv` fails if its destination is not an existing directory.
mkdir -p dog cat
mv dog.* dog/
mv cat.* cat/
将 25000 张训练集图片划分为 20000 张训练集和 5000 张测试集，分别放入 train 和 test 文件夹中，然后用 ImageFolder 读取。
# Transfer learning: start from a pretrained ResNet-18, freeze its weights,
# and replace the final fully connected layer with a 2-way classifier.
# NOTE: torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=models.ResNet18_Weights.DEFAULT`; the old form is kept so the
# script still runs on older torchvision versions.
net = models.resnet18(pretrained=True)

# Freeze every pretrained parameter so only the new head below is trained.
for param in net.parameters():
    param.requires_grad = False

# The new head is created after freezing, so its parameters keep the default
# requires_grad=True and remain trainable.
features = net.fc.in_features
net.fc = nn.Linear(features, 2)  # two output classes (cat / dog)

# Module.cuda() moves every submodule, including the new `fc` head,
# so no separate transfer of `net.fc` is needed.
if torch.cuda.is_available():
    net = net.cuda()
# Train the model for 20 epochs.
# NOTE(review): `train_loader`, `loss_fn` and `opt` are not defined in this
# excerpt and must exist before this loop runs. They are presumably a
# DataLoader over the train folder, a classification loss and an optimizer
# over net.fc's parameters — confirm.
use_cuda = torch.cuda.is_available()

# Set training mode explicitly: the evaluation loop calls net.eval(), so
# re-running this cell/section after evaluation would otherwise train in
# eval mode.
net.train()
for epoch in range(20):
    for x, y in train_loader:
        if use_cuda:
            x = x.cuda()
            y = y.cuda()
        pred = net(x)
        loss = loss_fn(pred, y)

        opt.zero_grad()
        loss.backward()
        opt.step()

    # Every other epoch, report the loss of the epoch's last mini-batch.
    if epoch % 2 == 0:
        print(epoch, loss.item())

# Saves the whole pickled module (architecture + weights). Loading it later
# requires the same class definitions to be importable.
torch.save(net, 'catvsdog_model.pth')
测试
def rightness(predictions, labels):
    """Count the correct predictions in a batch.

    Args:
        predictions: score tensor whose class dimension is dim 1
            (e.g. logits of shape (batch, num_classes)).
        labels: integer class labels; reshaped with ``view_as`` to match the
            predicted indices, so shapes like (batch,) or (batch, 1) both work.

    Returns:
        A tuple ``(rights, total)``. ``rights`` is a 0-dim tensor holding the
        number of correct predictions. ``total`` is ``len(labels)``.
    """
    # argmax over the class dimension yields the predicted class index. On
    # ties it returns the first maximum, like torch.max(..., 1)[1]. It does not
    # track gradients, so the legacy `.data` access is unnecessary.
    pred = predictions.argmax(dim=1)
    rights = pred.eq(labels.view_as(pred)).sum()
    return rights, len(labels)
# Count correct predictions over the test split.
# NOTE(review): `test_loader` is not defined in this excerpt and must exist
# before this loop runs.
rights = 0
length = 0

# Switch to inference mode once, before the loop, rather than on every batch.
net.eval()
use_cuda = torch.cuda.is_available()

# Inference only: skip building the autograd graph to save memory and time.
with torch.no_grad():
    for x, y in test_loader:
        # Move to the GPU only when one is available, so evaluation also works
        # on CPU-only machines.
        if use_cuda:
            x = x.cuda()
            y = y.cuda()
        pred = net(x)
        # Call rightness once per batch and unpack both results.
        batch_rights, batch_len = rightness(pred, y)
        rights += batch_rights
        length += batch_len