Capsule networks (CapsNets) are a hot new neural net architecture that may well have a profound impact on deep learning, in particular for computer vision. Wait, isn't computer vision pretty much solved already? Haven't we all seen fabulous examples of convolutional neural networks (CNNs) reaching super-human level in various computer vision tasks, such as classification, localization, object detection, semantic segmentation or instance segmentation (see Figure 1)?
Well, yes, we’ve seen fabulous CNNs, but:
- They were trained on huge numbers of images (or they reused parts of neural nets that had). CapsNets can generalize well using much less training data.
- CNNs don’t handle ambiguity very well. CapsNets do, so they can perform well even on crowded scenes (although they still struggle with backgrounds right now).
- CNNs lose plenty of information in the pooling layers. These layers reduce the spatial resolution (see Figure 2), so their outputs are invariant to small changes in the inputs. This is a problem when detailed information must be preserved throughout the network, such as in semantic segmentation. Today, this issue is addressed by building complex architectures around CNNs to recover some of the lost information. With CapsNets, detailed pose information (such as precise object position, rotation, thickness, skew, size, and so on) is preserved throughout the network, rather than lost and later recovered. Small changes to the inputs result in small changes to the outputs—information is preserved. This is called "equivariance." As a result, CapsNets can use the same simple and consistent architecture across different vision tasks.
- Finally, CNNs require extra components to automatically identify which object a part belongs to (e.g., this leg belongs to this sheep). CapsNets give you the hierarchy of parts for free.
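To make the pooling point above concrete, here is a minimal NumPy sketch. The `max_pool_2x2` helper and the two tiny 4x4 "images" are made up for illustration (they are not from the paper): the same bright pixel is shifted by one position, and after pooling the two outputs are identical, so the shift can no longer be recovered.

```python
import numpy as np

def max_pool_2x2(image):
    """Non-overlapping 2x2 max pooling on a 2D array with even height and width.

    Hypothetical helper written for this illustration only.
    """
    h, w = image.shape
    # Split the image into 2x2 blocks, then keep the max of each block.
    return image.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

# Two toy inputs (assumed data): the same bright pixel, shifted diagonally by one.
a = np.zeros((4, 4))
a[0, 0] = 1.0
b = np.zeros((4, 4))
b[1, 1] = 1.0

pooled_a = max_pool_2x2(a)
pooled_b = max_pool_2x2(b)

# Both pixels fall in the same 2x2 block, so pooling erases the difference.
same_output = np.array_equal(pooled_a, pooled_b)
print(same_output)
```

This is the invariance described in the list: handy for classification, but the precise position is gone by the time the next layer sees the data.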
Despite all their qualities, CapsNets are still far from perfect. Firstly, for now they don't perform as well as CNNs on datasets of larger images such as CIFAR10 or ImageNet. Moreover, they are computationally intensive, and they cannot detect two objects of the same type when they are too close to each other (this is called the "crowding problem," and it has been shown that humans have it, too). But the key ideas are extremely promising, and it seems likely that they just need a few tweaks to reach their full potential. After all, modern CNNs were invented in 1998, yet they only beat the state of the art on ImageNet in 2012, after a few tweaks.
So, what are CapsNets exactly?
In short, a CapsNet is composed of capsules rather than neurons. A capsule is a small group of neurons that learns to detect a particular object (e.g., a rectangle) within a given region of the image, and it outputs a vector (e.g., an 8-dimensional vector) whose length represents the estimated probability that the object is present[1], and whose orientation (e.g., in 8D space) encodes the object's pose parameters (e.g., precise position, rotation, etc.). If the object is changed slightly (e.g., shifted, rotated, resized, etc.) then the capsule will output a vector of the same length, but oriented slightly differently. Thus, capsules are equivariant.
Much like a regular neural network, a CapsNet is organized in multiple layers (see Figure 4). The capsules in the lowest layer are called primary capsules: each of them receives a small region of the image as input (called its receptive field), and it tries to detect the presence and pose of a particular pattern, for example a rectangle. Capsules in higher layers, called routing capsules, detect larger and more complex objects, such as boats.
The primary capsule layer is implemented using a few regular convolutional layers. For example, in the paper, the authors use two convolutional layers that output 256 6x6 feature maps containing scalars. They reshape this output to get 32 6x6 maps containing 8-dimensional vectors. Finally, they use a novel squashing function to ensure these vectors have a length between 0 and 1 (to represent a probability). And that's it: this gives the output of the primary capsules.
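Here is a rough NumPy sketch of the reshape-and-squash step. It assumes the squashing formula from the paper, v = (|s|^2 / (1 + |s|^2)) * (s / |s|). The random array is just a stand-in for the real convolutional output, and the way the 256 channels are grouped into 32 maps of 8 dimensions is my own convention, not necessarily the paper's.

```python
import numpy as np

def squash(s, axis=-1, eps=1e-8):
    """Shrink vectors so their length lies in [0, 1) while keeping their direction.

    eps is a small constant added for numerical safety (an assumption of this sketch).
    """
    squared_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    norm = np.sqrt(squared_norm + eps)
    return (squared_norm / (1.0 + squared_norm)) * (s / norm)

rng = np.random.default_rng(0)

# Stand-in for the output of the two convolutional layers:
# 256 feature maps of size 6x6, stored channels-last (assumed layout).
conv_output = rng.normal(size=(6, 6, 256))

# 256 scalars per location = 32 capsule maps x 8 dimensions per capsule.
capsule_inputs = conv_output.reshape(6, 6, 32, 8)

# Squash each 8D vector so its length can be read as a probability.
primary_capsules = squash(capsule_inputs)
lengths = np.linalg.norm(primary_capsules, axis=-1)

print(primary_capsules.shape)  # (6, 6, 32, 8)
print(lengths.shape)           # (6, 6, 32)
```

Long input vectors end up with a length close to 1 and short ones close to 0, which is what lets the length play the role of "probability that the object is present."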
The capsules in the next layers also try to detect objects and their pose, but they work very differently, using an algorithm called routing by agreement. This is where most of the magic of CapsNets lies. Let's look at an example.
Suppose there are just two primary capsules: one rectangle capsule and one triangle capsule, and suppose they both detected what they were looking for. Both the rectangle and the triangle could be part of either a house or a boat (see Figure 5). Given the pose of the rectangle, which is slightly rotated to the right, the house and the boat would have to be slightly rotated to the right as well. Given the pose of the triangle, the house would have to be almost upside down, whereas the boat would be slightly rotated to the right. Note that both the shapes and the whole/part relationships are learned during training. Now notice that the rectangle and the triangle agree on the pose of the boat, while they strongly disagree on the pose of the house. So, it is very likely that the rectangle and triangle are part of the same boat, and there is no house.
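The agreement check in this example can be sketched in a few lines. This is a deliberately simplified toy: I reduce each pose to a single rotation angle turned into a 2D unit vector, and the angles are invented to match the story above. In a real CapsNet the predictions would come from the transformations learned during training.

```python
import numpy as np

def pose_vector(angle_degrees):
    """Encode a rotation angle as a 2D unit vector (toy pose representation)."""
    a = np.deg2rad(angle_degrees)
    return np.array([np.cos(a), np.sin(a)])

# Assumed, hand-picked predictions: what each part says about each whole.
predictions = {
    "rectangle": {"boat": pose_vector(15), "house": pose_vector(15)},
    "triangle": {"boat": pose_vector(20), "house": pose_vector(170)},
}

# Agreement is measured here as the scalar product of the two predictions.
agreement = {
    whole: float(np.dot(predictions["rectangle"][whole],
                        predictions["triangle"][whole]))
    for whole in ("boat", "house")
}

most_likely = max(agreement, key=agreement.get)
print(most_likely)
```

The two boat predictions point in nearly the same direction, so their scalar product is close to 1, while the two house predictions point in nearly opposite directions and give a negative value.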
Since we are now confident that the rectangle and triangle are part of the boat, it makes sense to send the outputs of the rectangle and triangle capsules more to the boat capsule, and less to the house capsule: this way, the boat capsule will receive more useful input signal, and the house capsule will receive less noise. For each connection, the routing-by-agreement algorithm maintains a routing weight (see Figure 6): it increases the routing weight when there is agreement, and decreases it in case of disagreement.
The routing-by-agreement algorithm involves a few iterations of agreement-detection + routing-update (note that this happens for each prediction, not just once, and not just at training time). This is especially useful in crowded scenes: for example, in Figure 7, the scene is ambiguous because you could see an upside-down house in the middle, but this would leave the bottom rectangle and top triangle unexplained, so the routing-by-agreement algorithm will most likely converge to a better explanation: a boat at the bottom, and a house at the top. The ambiguity is said to be "explained away": the lower rectangle is best explained by the presence of a boat, which also explains the lower triangle, and once these two parts are explained away, the remaining parts are easily explained as a house.
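Putting the pieces together, the loop below is a minimal NumPy sketch of routing by agreement as I understand it from the paper: start with equal routing weights, compute each higher capsule's output from the weighted predictions, then nudge the weights using the scalar product between each prediction and that output. The function names, the toy 2D poses, and the choice of three iterations are assumptions of this sketch, not a reference implementation.

```python
import numpy as np

def squash(s, axis=-1, eps=1e-8):
    """Shrink vectors to a length in [0, 1) while keeping their direction."""
    squared_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    return (squared_norm / (1.0 + squared_norm)) * (s / np.sqrt(squared_norm + eps))

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def route_by_agreement(predictions, n_iterations=3):
    """predictions[i, j] is the vector that lower capsule i predicts for higher capsule j."""
    n_lower, n_higher, _ = predictions.shape
    logits = np.zeros((n_lower, n_higher))  # raw routing scores, all equal at first
    for _ in range(n_iterations):
        # Each lower capsule splits its output across the higher capsules.
        weights = softmax(logits, axis=1)
        weighted_sum = (weights[:, :, None] * predictions).sum(axis=0)
        outputs = squash(weighted_sum)
        # Agreement: scalar product between each prediction and the actual output.
        agreement = (predictions * outputs[None, :, :]).sum(axis=-1)
        logits = logits + agreement
    return weights, outputs

def pose_vector(angle_degrees):
    """Toy pose: a rotation angle encoded as a 2D unit vector (assumed representation)."""
    a = np.deg2rad(angle_degrees)
    return np.array([np.cos(a), np.sin(a)])

# Rows: rectangle, triangle. Columns: boat, house. Angles are made up.
predictions = np.array([
    [pose_vector(15), pose_vector(15)],
    [pose_vector(20), pose_vector(170)],
])

weights, outputs = route_by_agreement(predictions)
print(weights)                          # routing weights, one row per part
print(np.linalg.norm(outputs, axis=1))  # output lengths for boat and house
```

With these toy inputs, both parts end up routing more of their output to the boat capsule than to the house capsule, and the boat capsule's output vector is the longer of the two.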
And that’s it—you know the key ideas behind CapsNets! If you want more details, check out my two videos on CapsNets (one on the architecture and another on the implementation) and my commented TensorFlow implementation (Jupyter Notebook). Please don’t hesitate to comment on the videos, file issues on GitHub if you see any, or contact me on Twitter @aureliengeron. I hope you found this post useful!