libtorch学习笔记(7)- VGG网络训练和测试
VGG网络训练和测试
简单介绍
VGG是卷积网络里面比较常见的网络模型,相比LeNet要复杂一些,但是都属于拓补结构简单直接的前置反馈网络,详细信息可参考论文VERY DEEP CONVOLUTIONAL NETWORKS FOR LARGE-SCALE IMAGE RECOGNITION,VGG网络能够提取更多的图像特征,最后输出的特诊向量信息量更丰富,所以可以进行更大规模的分类,前面介绍的LeNet5可以产生10个分类,分别对应0~9, VGG可以产生上万个分类,识别更多的类型。VGG也是Faster RCNN的基础,Faster RCNN在现实当中实用性更强,能在任意图像内进行目标定位,然后再进行目标识别。 下图是从论文中截取的一张网络配置图,并加上代码中对应的层: 这张表后面结合代码再做详细描述,与前面笔记中提到的LeNet5相比:
Network
网络层数
权重层数
参数个数
LeNet5
7
5
61706
VGG16(D)
39
16
138357544
可想而知VGG要比LeNet5复杂很多,运算量也大很多,训练时间更长,训练的网络状态所占空间也越大。
在我的机器上(MacBook Pro 2017), 用CPU训练,60000张MNIST训练图片(1x28x28)2轮学习花了10分钟左右,10000张测试图片花了10秒,但是8000张左右猫狗训练集(3x可变长宽)2轮学习花了6.7个小时, 2000张测试图片识别花了11分钟左右。GPU可能快很多,目前没试过。
从上表中也能看出一般网络模型命名规律:网络模型名 + 权重层数,所以有LeNet-5, VGG-11, VGG-16和VGG-19这些名称。
网络构建
根据上述论文,选择ConvNet Configuration D,也称作VGG16,基于c++ libtorch库用如下代码创建了它,在上图中也标出了每层对应的module名称,这些网络层的命令是,模型名称缩写+所在第几层,如C29,就是卷积层(Convolutional network, C)在本网络中位于第29层, FC38就是全连接层(FullConnection, FC)在此网络中位于第38层。 另外有些网络层就是做一个简单操作,比如RELU, MaxPool等,就不注册网络层,具体就在forward中当作function来in-place处理。
VGGNet::VGGNet(int num_classes)
: C1 (register_module("C1", Conv2d(Conv2dOptions( 3, 64, 3).padding(1))))
, C3 (register_module("C3", Conv2d(Conv2dOptions( 64, 64, 3).padding(1))))
, C6 (register_module("C6", Conv2d(Conv2dOptions( 64, 128, 3).padding(1))))
, C8 (register_module("C8", Conv2d(Conv2dOptions(128, 128, 3).padding(1))))
, C11 (register_module("C11", Conv2d(Conv2dOptions(128, 256, 3).padding(1))))
, C13 (register_module("C13", Conv2d(Conv2dOptions(256, 256, 3).padding(1))))
, C15 (register_module("C15", Conv2d(Conv2dOptions(256, 256, 3).padding(1))))
, C18 (register_module("C18", Conv2d(Conv2dOptions(256, 512, 3).padding(1))))
, C20 (register_module("C20", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
, C22 (register_module("C22", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
, C25 (register_module("C25", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
, C27 (register_module("C27", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
, C29 (register_module("C29", Conv2d(Conv2dOptions(512, 512, 3).padding(1))))
, FC32(register_module("FC32",Linear(512 * 7 * 7, 4096)))
, FC35(register_module("FC35",Linear(4096, 4096)))
, FC38(register_module("FC38",Linear(4096, num_classes)))
{
...
}
torch::Tensor VGGNet::forward(torch::Tensor input)
{
namespace F = torch::nn::functional;
// block#1
auto x = F::max_pool2d(F::relu(C3(F::relu(C1(input)))), F::MaxPool2dFuncOptions(2));
// block#2
x = F::max_pool2d(F::relu(C8(F::relu(C6(x)))), F::MaxPool2dFuncOptions(2));
// block#3
x = F::max_pool2d(F::relu(C15(F::relu(C13(F::relu(C11(x)))))), F::MaxPool2dFuncOptions(2));
// block#4
x = F::max_pool2d(F::relu(C22(F::relu(C20(F