.. _sec_vgg:

Networks Using Blocks (VGG)
===========================

While AlexNet offered empirical evidence that deep CNNs can achieve good
results, it did not provide a general template to guide subsequent
researchers in designing new networks. In the following sections, we will
introduce several heuristic concepts commonly used to design deep networks.

Progress in this field mirrors that of VLSI (very large scale integration)
in chip design, where engineers moved from placing transistors to logical
elements to logic blocks :cite:`Mead.1980`. Similarly, the design of neural
network architectures has grown progressively more abstract, with
researchers moving from thinking in terms of individual neurons to whole
layers, and now to blocks, i.e., repeating patterns of layers. A decade
later, this has progressed further still: researchers now repurpose entire
trained models for different, albeit related, tasks. Such large pretrained
models are typically called *foundation models*
:cite:`bommasani2021opportunities`.

Back to network design. The idea of using blocks first emerged from the
Visual Geometry Group (VGG) at Oxford University, in their eponymously
named *VGG* network :cite:`Simonyan.Zisserman.2014`. It is easy to
implement these repeated structures in code with any modern deep learning
framework by using loops and subroutines.

**pytorch**

.. code:: python

    import torch
    from torch import nn
    from d2l import torch as d2l

**mxnet**

.. code:: python

    from mxnet import init, np, npx
    from mxnet.gluon import nn
    from d2l import mxnet as d2l

    npx.set_np()

**jax**

.. code:: python

    import jax
    from flax import linen as nn
    from d2l import jax as d2l

**tensorflow**

.. code:: python

    import tensorflow as tf
    from d2l import tensorflow as d2l

.. _subsec_vgg-blocks:

VGG Blocks
----------

The basic building block of CNNs is a sequence of the following:
(i) a convolutional layer with padding to maintain the resolution,
(ii) a nonlinearity such as a ReLU, and (iii) a pooling layer such as
max-pooling to reduce the resolution. One of the problems with this
approach is that the spatial resolution decreases quite rapidly. In
particular, this imposes a hard limit of :math:`\log_2 d` convolutional
layers on the network before all dimensions (:math:`d`) are used up. For
instance, in the case of ImageNet, it would be impossible to have more
than 8 convolutional layers in this way.

The key idea of :cite:t:`Simonyan.Zisserman.2014` was to use *multiple*
convolutions in between downsampling via max-pooling in the form of a
block. They were primarily interested in whether deep or wide networks
perform better. For instance, the successive application of two
:math:`3 \times 3` convolutions touches the same pixels as a single
:math:`5 \times 5` convolution does. At the same time, the latter uses
approximately as many parameters (:math:`25 \cdot c^2`) as three
:math:`3 \times 3` convolutions do
(:math:`3 \cdot 9 \cdot c^2 = 27 \cdot c^2`), where :math:`c` is the
number of channels. In a rather detailed analysis they showed that deep
and narrow networks significantly outperform their shallow counterparts.
This set deep learning on a quest for ever deeper networks, with over
100 layers for typical applications. Stacking :math:`3 \times 3`
convolutions has become a gold standard in later deep networks (a design
decision only recently revisited by :cite:t:`liu2022convnet`).
Consequently, fast implementations for small convolutions have become a
staple on GPUs :cite:`lavin2016fast`.

Back to VGG: a VGG block consists of a *sequence* of convolutions with
:math:`3\times3` kernels with padding of 1 (keeping height and width),
followed by a :math:`2 \times 2` max-pooling layer with stride of 2
(halving height and width after each block). In the code below, we
define a function called ``vgg_block`` to implement one VGG block. It
takes two arguments: the number of convolutional layers ``num_convs``
and the number of output channels ``num_channels`` (called
``out_channels`` in the PyTorch and JAX implementations).

**pytorch**

.. code:: python

    def vgg_block(num_convs, out_channels):
        layers = []
        for _ in range(num_convs):
            # 3x3 convolutions with padding 1 keep height and width unchanged
            layers.append(nn.LazyConv2d(out_channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU())
        # 2x2 max-pooling with stride 2 halves height and width
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*layers)

**mxnet**

.. code:: python

    def vgg_block(num_convs, num_channels):
        blk = nn.Sequential()
        for _ in range(num_convs):
            blk.add(nn.Conv2D(num_channels, kernel_size=3,
                              padding=1, activation='relu'))
        blk.add(nn.MaxPool2D(pool_size=2, strides=2))
        return blk

**jax**

.. code:: python

    def vgg_block(num_convs, out_channels):
        layers = []
        for _ in range(num_convs):
            layers.append(nn.Conv(out_channels, kernel_size=(3, 3), padding=(1, 1)))
            layers.append(nn.relu)
        layers.append(lambda x: nn.max_pool(x, window_shape=(2, 2), strides=(2, 2)))
        return nn.Sequential(layers)

**tensorflow**

.. code:: python

    def vgg_block(num_convs, num_channels):
        blk = tf.keras.models.Sequential()
        for _ in range(num_convs):
            blk.add(tf.keras.layers.Conv2D(num_channels, kernel_size=3,
                                           padding='same', activation='relu'))
        blk.add(tf.keras.layers.MaxPool2D(pool_size=2, strides=2))
        return blk

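As a quick sanity check, the following minimal sketch (assuming the
PyTorch ``vgg_block`` defined above and a dummy input tensor) confirms
that a block produces the requested number of channels and halves the
spatial resolution, in line with the resolution budget discussed above:

.. code:: python

    X = torch.randn(1, 3, 224, 224)                 # dummy batch: one 3-channel image
    blk = vgg_block(num_convs=2, out_channels=128)  # two 3x3 convs, then max-pooling
    print(blk(X).shape)                             # expected: torch.Size([1, 128, 112, 112])
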
.. _subsec_vgg-network:

VGG Network
-----------

Like AlexNet and LeNet, the VGG Network can be partitioned into two
parts: the first consisting mostly of convolutional and pooling layers
and the second consisting of fully connected layers that are identical
to those in AlexNet. The key difference is that the convolutional layers
are grouped in nonlinear transformations that leave the dimensionality
unchanged, followed by a resolution-reduction step, as depicted in
:numref:`fig_vgg`.

.. _fig_vgg:

.. figure:: ../img/vgg.svg
   :width: 400px

   From AlexNet to VGG. The key difference is that VGG consists of
   blocks of layers, whereas AlexNet’s layers are all designed
   individually.

The convolutional part of the network connects several VGG blocks from
:numref:`fig_vgg` (also defined in the ``vgg_block`` function) in
succession. This grouping of convolutions is a pattern that has remained
almost unchanged over the past decade, although the specific choice of
operations has undergone considerable modification. The variable
``arch`` consists of a list of tuples (one per block), where each
contains two values: the number of convolutional layers and the number
of output channels, which are precisely the arguments required to call
the ``vgg_block`` function. As such, VGG defines a *family* of networks
rather than just a specific manifestation. To build a specific network
we simply iterate over ``arch`` to compose the blocks.

**pytorch**

.. code:: python

    class VGG(d2l.Classifier):
        def __init__(self, arch, lr=0.1, num_classes=10):
            super().__init__()
            self.save_hyperparameters()
            # Convolutional part: one vgg_block per (num_convs, out_channels) pair
            conv_blks = []
            for (num_convs, out_channels) in arch:
                conv_blks.append(vgg_block(num_convs, out_channels))
            # Fully connected part, identical to that of AlexNet
            self.net = nn.Sequential(
                *conv_blks, nn.Flatten(),
                nn.LazyLinear(4096), nn.ReLU(), nn.Dropout(0.5),
                nn.LazyLinear(4096), nn.ReLU(), nn.Dropout(0.5),
                nn.LazyLinear(num_classes))
            self.net.apply(d2l.init_cnn)

**mxnet**

.. code:: python

    class VGG(d2l.Classifier):
        def __init__(self, arch, lr=0.1, num_classes=10):
            super().__init__()
            self.save_hyperparameters()
            self.net = nn.Sequential()
            for (num_convs, num_channels) in arch:
                self.net.add(vgg_block(num_convs, num_channels))
            self.net.add(nn.Dense(4096, activation='relu'),
                         nn.Dropout(0.5),
                         nn.Dense(4096, activation='relu'),
                         nn.Dropout(0.5),
                         nn.Dense(num_classes))
            self.net.initialize(init.Xavier())

**jax**

.. code:: python

    class VGG(d2l.Classifier):
        arch: list
        lr: float = 0.1
        num_classes: int = 10
        training: bool = True

        def setup(self):
            conv_blks = []
            for (num_convs, out_channels) in self.arch:
                conv_blks.append(vgg_block(num_convs, out_channels))
            self.net = nn.Sequential([
                *conv_blks,
                lambda x: x.reshape((x.shape[0], -1)),  # flatten
                nn.Dense(4096), nn.relu,
                nn.Dropout(0.5, deterministic=not self.training),
                nn.Dense(4096), nn.relu,
                nn.Dropout(0.5, deterministic=not self.training),
                nn.Dense(self.num_classes)])

**tensorflow**

.. code:: python

    class VGG(d2l.Classifier):
        def __init__(self, arch, lr=0.1, num_classes=10):
            super().__init__()
            self.save_hyperparameters()
            self.net = tf.keras.models.Sequential()
            for (num_convs, num_channels) in arch:
                self.net.add(vgg_block(num_convs, num_channels))
            self.net.add(
                tf.keras.models.Sequential([
                    tf.keras.layers.Flatten(),
                    tf.keras.layers.Dense(4096, activation='relu'),
                    tf.keras.layers.Dropout(0.5),
                    tf.keras.layers.Dense(4096, activation='relu'),
                    tf.keras.layers.Dropout(0.5),
                    tf.keras.layers.Dense(num_classes)]))

The original VGG network had five convolutional blocks, among which the
first two have one convolutional layer each and the latter three contain
two convolutional layers each. The first block has 64 output channels
and each subsequent block doubles the number of output channels, until
that number reaches 512. Since this network uses eight convolutional
layers and three fully connected layers, it is often called VGG-11.

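The layer count behind the name is easy to verify from the architecture
specification itself. The sketch below (plain Python over the ``arch``
tuple used in this section) sums the per-block convolution counts and
adds the three fully connected layers:

.. code:: python

    arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))
    num_convs = sum(n for n, _ in arch)  # 1 + 1 + 2 + 2 + 2 = 8 convolutional layers
    print(num_convs + 3)                 # plus 3 fully connected layers: 11, hence VGG-11
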
**pytorch**

.. code:: python

    VGG(arch=((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))).layer_summary(
        (1, 1, 224, 224))

.. parsed-literal::
    :class: output

    Sequential output shape:    torch.Size([1, 64, 112, 112])
    Sequential output shape:    torch.Size([1, 128, 56, 56])
    Sequential output shape:    torch.Size([1, 256, 28, 28])
    Sequential output shape:    torch.Size([1, 512, 14, 14])
    Sequential output shape:    torch.Size([1, 512, 7, 7])
    Flatten output shape:       torch.Size([1, 25088])
    Linear output shape:        torch.Size([1, 4096])
    ReLU output shape:          torch.Size([1, 4096])
    Dropout output shape:       torch.Size([1, 4096])
    Linear output shape:        torch.Size([1, 4096])
    ReLU output shape:          torch.Size([1, 4096])
    Dropout output shape:       torch.Size([1, 4096])
    Linear output shape:        torch.Size([1, 10])

**mxnet**

.. code:: python

    VGG(arch=((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))).layer_summary(
        (1, 1, 224, 224))

.. parsed-literal::
    :class: output

    Sequential output shape:    (1, 64, 112, 112)
    Sequential output shape:    (1, 128, 56, 56)
    Sequential output shape:    (1, 256, 28, 28)
    Sequential output shape:    (1, 512, 14, 14)
    Sequential output shape:    (1, 512, 7, 7)
    Dense output shape:         (1, 4096)
    Dropout output shape:       (1, 4096)
    Dense output shape:         (1, 4096)
    Dropout output shape:       (1, 4096)
    Dense output shape:         (1, 10)
    [22:40:53] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU

**jax**

.. code:: python

    VGG(arch=((1, 64), (1, 128), (2, 256), (2, 512), (2, 512)),
        training=False).layer_summary((1, 224, 224, 1))

.. parsed-literal::
    :class: output

    Sequential output shape:    (1, 112, 112, 64)
    Sequential output shape:    (1, 56, 56, 128)
    Sequential output shape:    (1, 28, 28, 256)
    Sequential output shape:    (1, 14, 14, 512)
    Sequential output shape:    (1, 7, 7, 512)
    function output shape:      (1, 25088)
    Dense output shape:         (1, 4096)
    custom_jvp output shape:    (1, 4096)
    Dropout output shape:       (1, 4096)
    Dense output shape:         (1, 4096)
    custom_jvp output shape:    (1, 4096)
    Dropout output shape:       (1, 4096)
    Dense output shape:         (1, 10)

**tensorflow**

.. code:: python

    VGG(arch=((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))).layer_summary(
        (1, 224, 224, 1))

.. parsed-literal::
    :class: output

    Sequential output shape:    (1, 112, 112, 64)
    Sequential output shape:    (1, 56, 56, 128)
    Sequential output shape:    (1, 28, 28, 256)
    Sequential output shape:    (1, 14, 14, 512)
    Sequential output shape:    (1, 7, 7, 512)
    Sequential output shape:    (1, 10)

As you can see, we halve height and width at each block, finally
reaching a height and width of 7 (since :math:`224/2^5 = 7`) before
flattening the representations for processing by the fully connected
part of the network. :cite:t:`Simonyan.Zisserman.2014` described several
other variants of VGG. In fact, it has become the norm to propose
*families* of networks with different speed–accuracy trade-offs when
introducing a new architecture.

Training
--------

Since VGG-11 is computationally more demanding than AlexNet, we
construct a network with a smaller number of channels. This is more than
sufficient for training on Fashion-MNIST. The model training process is
similar to that of AlexNet in :numref:`sec_alexnet`. Again observe the
close match between validation and training loss, suggesting only a
small amount of overfitting.

**pytorch**

.. code:: python

    model = VGG(arch=((1, 16), (1, 32), (2, 64), (2, 128), (2, 128)), lr=0.01)
    trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
    data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))
    model.apply_init([next(iter(data.get_dataloader(True)))[0]], d2l.init_cnn)
    trainer.fit(model, data)

.. figure:: output_vgg_4a7574_63_0.svg

**mxnet**

.. code:: python

    model = VGG(arch=((1, 16), (1, 32), (2, 64), (2, 128), (2, 128)), lr=0.01)
    trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
    data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))
    trainer.fit(model, data)

.. figure:: output_vgg_4a7574_66_0.svg

**jax**

.. code:: python

    model = VGG(arch=((1, 16), (1, 32), (2, 64), (2, 128), (2, 128)), lr=0.01)
    trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
    data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))
    trainer.fit(model, data)

.. figure:: output_vgg_4a7574_69_0.svg

**tensorflow**

.. code:: python

    trainer = d2l.Trainer(max_epochs=10)
    data = d2l.FashionMNIST(batch_size=128, resize=(224, 224))
    with d2l.try_gpu():
        model = VGG(arch=((1, 16), (1, 32), (2, 64), (2, 128), (2, 128)), lr=0.01)
        trainer.fit(model, data)

.. figure:: output_vgg_4a7574_72_0.svg

Summary
-------

One might argue that VGG is the first truly modern convolutional neural
network. While AlexNet introduced many of the components that make deep
learning effective at scale, it is VGG that arguably introduced key
properties such as blocks of multiple convolutions and a preference for
deep and narrow networks. It is also the first network that is actually
an entire family of similarly parametrized models, offering the
practitioner a wide range of trade-offs between complexity and speed.
This is also where modern deep learning frameworks shine: it is no
longer necessary to generate XML configuration files to specify a
network; instead, we can assemble such networks through simple Python
code.

More recently, ParNet :cite:`Goyal.Bochkovskiy.Deng.ea.2021`
demonstrated that it is possible to achieve competitive performance
using a much shallower architecture through a large number of parallel
computations. This is an exciting development and there is hope that it
will influence architecture designs in the future. For the remainder of
the chapter, though, we will follow the path of scientific progress over
the past decade.

Exercises
---------

1. Compared with AlexNet, VGG is much slower in terms of computation,
   and it also needs more GPU memory.

   1. Compare the number of parameters needed for AlexNet and VGG.
   2. Compare the number of floating point operations used in the
      convolutional layers and in the fully connected layers.
   3. How could you reduce the computational cost created by the fully
      connected layers?

2. When displaying the dimensions associated with the various layers of
   the network, we only see the information associated with eight blocks
   (plus some auxiliary transforms), even though the network has 11
   layers. Where did the remaining three layers go?
3. Use Table 1 in the VGG paper :cite:`Simonyan.Zisserman.2014` to
   construct other common models, such as VGG-16 or VGG-19.
4. Upsampling the resolution in Fashion-MNIST eight-fold from
   :math:`28 \times 28` to :math:`224 \times 224` dimensions is very
   wasteful. Try modifying the network architecture and resolution
   conversion, e.g., to 56 or to 84 dimensions for its input instead.
   Can you do so without reducing the accuracy of the network? Consult
   the VGG paper :cite:`Simonyan.Zisserman.2014` for ideas on adding
   more nonlinearities prior to downsampling.

`Discussions `__