VGG16 Network Training

Training a VGG16 network on the CIFAR10 dataset with PyTorch, code included.

(Figure: structure of VGG16)

Choosing the Dataset

I just went with CIFAR10 and didn't bother setting up anything else. The code is as follows:

import torch
import torchvision
import torchvision.transforms as transforms

# Resize CIFAR10's 32x32 images up to VGG16's 224x224 input size,
# convert to tensors, and normalize each channel to [-1, 1].
transform = transforms.Compose([transforms.Resize((224, 224)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5),
                                                     (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
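As a quick sanity check (my addition, not in the original post), you can pull one batch and confirm that the Resize transform really produces 224x224 inputs:

# Quick check: one batch should have shape [4, 3, 224, 224].
images, labels = next(iter(trainloader))
print(images.shape)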

If you are running this on Windows, remember to change num_workers to 0; the multi-process data loading does not work out of the box there.
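Alternatively (my suggestion, not from the original post), you can keep multiple workers on Windows by putting the code that consumes the DataLoader behind a main-module guard, since Windows starts worker processes with spawn and re-imports the script:

# Sketch: wrap the training/evaluation code in main() and guard it, so that
# spawned DataLoader worker processes don't re-execute the whole script.
def main():
    for epoch in range(8):                       # the training loop from later in the post
        for i, data in enumerate(trainloader, 0):
            pass                                 # training step goes here

if __name__ == '__main__':
    main()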

Choosing the Network

VGG16 is just something to play around with here. Its drawback is the huge number of parameters, well over a hundred million, so the weights alone take up more than 500 MB and training is fairly slow.

(Figure: VGG16 parameters)

When using VGG16 to classify CIFAR10, pay attention to the input and output sizes. VGG16 was originally designed for the ImageNet competition, so its input images are 224x224 and its output has 1000 classes, while CIFAR10 images are 32x32 with 10 classes. So I interpolate the CIFAR10 images up to 224x224 and change the size of the last fully connected layer to 10.

I didn't change the rest of the network structure; when you try this yourself, you may want to modify the network to better suit CIFAR10.
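If you'd rather not build the network by hand, the same input/output adaptation can also be done on torchvision's built-in VGG16. This is just an alternative sketch, not what I do below:

import torch.nn as nn
import torchvision

# Alternative sketch: take torchvision's VGG16 and swap the last fully
# connected layer from ImageNet's 1000 classes to CIFAR10's 10 classes.
vgg = torchvision.models.vgg16(weights=None)   # older torchvision: pretrained=False
vgg.classifier[6] = nn.Linear(4096, 10)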

Building the Network

Since this is just practice, I didn't use the VGG16 that ships with torchvision but built it layer by layer. The code is as follows:

import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.pool = nn.MaxPool2d(2, 2)

        # One BatchNorm per convolutional block, shared by the convs in that block.
        self.batchNorm1 = nn.BatchNorm2d(64)
        self.batchNorm2 = nn.BatchNorm2d(128)
        self.batchNorm3 = nn.BatchNorm2d(256)
        self.batchNorm4 = nn.BatchNorm2d(512)
        self.batchNorm5 = nn.BatchNorm2d(512)

        # The 13 convolutional layers of VGG16, all 3x3 with padding 1.
        self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1, bias=False)
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1, bias=False)
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1, bias=False)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1, bias=False)
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1, bias=False)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1, bias=False)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1, bias=False)
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1, bias=False)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1, bias=False)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1, bias=False)
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1, bias=False)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1, bias=False)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1, bias=False)

        # Classifier: the last layer outputs 10 classes instead of ImageNet's 1000.
        self.fc1 = nn.Linear(512 * 7 * 7, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 10)
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        x = F.relu(self.batchNorm1(self.conv1_1(x)))
        x = F.relu(self.batchNorm1(self.conv1_2(x)))
        x = self.pool(x)

        x = F.relu(self.batchNorm2(self.conv2_1(x)))
        x = F.relu(self.batchNorm2(self.conv2_2(x)))
        x = self.pool(x)

        x = F.relu(self.batchNorm3(self.conv3_1(x)))
        x = F.relu(self.batchNorm3(self.conv3_2(x)))
        x = F.relu(self.batchNorm3(self.conv3_3(x)))
        x = self.pool(x)

        x = F.relu(self.batchNorm4(self.conv4_1(x)))
        x = F.relu(self.batchNorm4(self.conv4_2(x)))
        x = F.relu(self.batchNorm4(self.conv4_3(x)))
        x = self.pool(x)

        x = F.relu(self.batchNorm5(self.conv5_1(x)))
        x = F.relu(self.batchNorm5(self.conv5_2(x)))
        x = F.relu(self.batchNorm5(self.conv5_3(x)))
        x = self.pool(x)

        # Flatten the 512 x 7 x 7 feature map for the fully connected layers.
        x = x.view(-1, 512 * 7 * 7)

        x = F.relu(self.fc1(x))
        # x = self.drop(x)
        x = F.relu(self.fc2(x))
        # x = self.drop(x)

        x = self.fc3(x)
        return x


net = Net()

The two commented-out x = self.drop(x) lines are there to prevent overfitting, but with them the training loss dropped far too slowly, so I gave up on dropout after two epochs.
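As a quick check of the parameter count mentioned earlier (my addition, not in the original post), you can sum the weights of net directly; for this architecture the total comes out around 134 million, which at 4 bytes per float32 weight is roughly the 500+ MB quoted above:

# Count trainable parameters and estimate the float32 checkpoint size.
num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(num_params)                                 # roughly 1.3e8 for this network
print('%.1f MB' % (num_params * 4 / 1024 ** 2))   # float32 weights, a bit over 500 MB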

Training Process

The learning rate and momentum are 0.001 and 0.9 respectively, a very common setting.
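The post doesn't show where criterion and optimizer are created; a minimal setup consistent with these values might look like this (cross-entropy loss and plain SGD are my assumptions):

import torch.optim as optim

# Assumed setup: cross-entropy loss and SGD with the lr/momentum stated above.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)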

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
PATH = './cifar_VGGnet.pth'
net.to(device)

for epoch in range(8):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Print the average loss every 200 mini-batches.
        running_loss += loss.item()
        if i % 200 == 199:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0
    # Save a checkpoint at the end of every epoch.
    torch.save(net.state_dict(), PATH)

print('Finished Training')

Then you can happily let it run. On my 1660 Ti one epoch takes about 23 minutes; I ran a dozen or so in total, and the loss dropped from 2.303 to 0.078.

Test Results
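If you evaluate in a separate session rather than right after training, you would first rebuild the model and reload the saved weights; a small sketch using the PATH defined in the training code:

# Rebuild the network and load the checkpoint saved during training.
net = Net()
net.load_state_dict(torch.load(PATH, map_location=device))
net.to(device)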

correct = 0
total = 0
net.eval()  # switch BatchNorm (and Dropout, if enabled) to inference mode
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Print the running accuracy every 100 test batches.
        if total % 400 == 0:
            print(100 * correct / total)

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))
Accuracy of the network on the 10000 test images: 83 %
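If you also want to see which classes the network struggles with, a per-class breakdown can be computed the same way (my addition, reusing the classes tuple defined with the data loaders):

# Per-class accuracy sketch.
class_correct = [0] * 10
class_total = [0] * 10
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        for j in range(labels.size(0)):
            label = labels[j].item()
            class_total[label] += 1
            class_correct[label] += int(predicted[j].item() == label)

for i in range(10):
    print('%5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))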

Looking around online, other people's VGG16 classifiers reach about 90% accuracy on CIFAR10, but they trained for 40 epochs. Due to time constraints I didn't keep training; after all, this is just a practice project.

The complete code is available at VGG-Net-16.
