Prototypical networks
Prototypical networks are another simple, efficient few-shot learning algorithm. Like siamese networks, a prototypical network learns a metric space in which to perform classification. The basic idea is to create a prototypical representation of each class and classify a query point (that is, a new point) based on the distance between the query point and each class prototype.
Let's say we have a support set comprising images of lions, elephants, and dogs, as shown in the following diagram:
So, we have three classes: {lion, elephant, dog}. Now we need to create a prototypical representation for each of these three classes. How can we build the prototypes? First, we learn an embedding of each data point using an embedding function. The embedding function can be any function that extracts features. Since our input is an image, we can use a convolutional network as our embedding function, which will extract features from the input image:
Once we have the embedding of each data point, we take the mean of the embeddings of the data points in each class to form the class prototype, as shown in the following diagram. So, a class prototype is simply the mean embedding of the data points in that class:
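To make the averaging step concrete, here is a minimal NumPy sketch. The function name class_prototypes and the 2-D toy embeddings are assumptions made for illustration; in practice the embeddings would come from the convolutional network rather than being typed in by hand.

```python
import numpy as np

def class_prototypes(embeddings, labels):
    """Return (classes, prototypes): one mean embedding per class."""
    embeddings = np.asarray(embeddings, dtype=float)
    labels = np.asarray(labels)
    classes = np.unique(labels)  # sorted unique class labels
    # Average the embeddings that belong to each class.
    prototypes = np.stack([embeddings[labels == c].mean(axis=0) for c in classes])
    return classes, prototypes

# Toy support set: made-up 2-D vectors standing in for CNN features.
support_embeddings = [[1.0, 0.0], [3.0, 0.0],   # lion
                      [0.0, 4.0], [0.0, 6.0],   # elephant
                      [5.0, 5.0], [7.0, 7.0]]   # dog
support_labels = ["lion", "lion", "elephant", "elephant", "dog", "dog"]

classes, prototypes = class_prototypes(support_embeddings, support_labels)
```

Note that np.unique returns the class labels in sorted order, so the rows of prototypes follow that order rather than the order in which the classes first appear.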
Similarly, when a new data point comes in, that is, a query point whose label we want to predict, we generate its embedding using the same embedding function that we used to create the class prototypes. In our example, that means passing the query point through the same convolutional network:
Once we have the embedding for our query point, we compute the distance between the query point embedding and each class prototype to find which class the query point belongs to. We can use Euclidean distance as the distance measure, as shown here:
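The distance computation can be sketched in a few lines of NumPy. The helper name euclidean_distances, the prototype values, and the query embedding below are all made up for illustration.

```python
import numpy as np

def euclidean_distances(query, prototypes):
    """Euclidean distance from one query embedding to each class prototype."""
    query = np.asarray(query, dtype=float)
    prototypes = np.asarray(prototypes, dtype=float)
    # Broadcasting subtracts the query from every prototype row.
    return np.linalg.norm(prototypes - query, axis=1)

class_names = ["lion", "elephant", "dog"]
prototypes = np.array([[2.0, 0.0], [0.0, 5.0], [6.0, 6.0]])  # one row per class
query = np.array([5.0, 4.0])                                 # toy query embedding

distances = euclidean_distances(query, prototypes)
nearest = class_names[int(np.argmin(distances))]
```

With these toy values, the query lies closest to the third prototype, so nearest is "dog".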
After finding the distance between the query point embedding and each class prototype, we apply softmax to the negative of these distances to get the probabilities, so that a smaller distance yields a higher probability. Since we have three classes, that is, lion, elephant, and dog, we will get three probabilities. The class with the highest probability is the predicted class of our query point.
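As a rough sketch of this step, the snippet below turns distances into class probabilities by taking a softmax over the negated distances, so that the nearest prototype receives the largest probability. The function name and the distance values are assumptions made for illustration.

```python
import numpy as np

def distances_to_probabilities(distances):
    """Softmax over negated distances: closer prototypes get higher probability."""
    logits = -np.asarray(distances, dtype=float)
    logits = logits - logits.max()   # subtract the max for numerical stability
    exp = np.exp(logits)
    return exp / exp.sum()

class_names = ["lion", "elephant", "dog"]
distances = [0.5, 3.0, 4.0]          # made-up distances to each class prototype

probs = distances_to_probabilities(distances)
predicted = class_names[int(np.argmax(probs))]
```

Here the lion prototype is nearest, so predicted is "lion" and the three probabilities sum to 1.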
Since we want our network to learn from only a few data points, that is, to perform few-shot learning, we train it under the same conditions. So, we use episodic training: for each episode, we randomly sample a few data points from each class in our dataset, call that the support set, and train the network using only the support set instead of the whole dataset. Similarly, we randomly sample a point from the dataset as a query point and try to predict its class. In this way, our network learns how to learn from a small set of data points.
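Episode sampling can be sketched as follows. This is only one plausible way to do it: the function name sample_episode and the parameters n_way, k_shot, and n_query are hypothetical, and the sketch assumes that query points are drawn from the same sampled classes and kept disjoint from the support set.

```python
import numpy as np

def sample_episode(labels, n_way, k_shot, n_query, rng):
    """Pick n_way classes, then k_shot support and n_query query indices per class."""
    labels = np.asarray(labels)
    classes = rng.choice(np.unique(labels), size=n_way, replace=False)
    support_idx, query_idx = [], []
    for c in classes:
        # Shuffle the indices of this class, then split them into support and query.
        idx = rng.permutation(np.flatnonzero(labels == c))
        support_idx.extend(idx[:k_shot])
        query_idx.extend(idx[k_shot:k_shot + n_query])
    return classes, np.array(support_idx), np.array(query_idx)

rng = np.random.default_rng(0)
labels = np.repeat(np.arange(5), 10)   # toy dataset: 5 classes, 10 points each

classes, support_idx, query_idx = sample_episode(
    labels, n_way=3, k_shot=2, n_query=1, rng=rng)
```

Each call produces a fresh small classification task, which is what the network sees at every training step.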
The overall flow of our prototypical network is shown in the following diagram. First, we generate the embeddings for all of the data points in our support set and build each class prototype by taking the mean of the embeddings of the data points in that class. We also generate the embedding for our query point. Then, we compute the distance between the query point embedding and each class prototype, using Euclidean distance as the distance measure. Finally, we apply softmax to the negative distances to get the probabilities. As you can see in the following diagram, since our query point is a lion, the probability for lion is high (0.9):
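The same flow can be written compactly in symbols. The notation here is an assumption, since the text does not name these symbols: f is the embedding function, S_k is the set of support points belonging to class k, c_k is the prototype of class k, and x is the query point.

```latex
\begin{align}
c_k &= \frac{1}{|S_k|} \sum_{x_i \in S_k} f(x_i) \\
d_k(x) &= \lVert f(x) - c_k \rVert_2 \\
p(y = k \mid x) &= \frac{\exp\bigl(-d_k(x)\bigr)}{\sum_{k'} \exp\bigl(-d_{k'}(x)\bigr)}
\end{align}
```

The minus sign in the last line is what makes a small distance correspond to a high probability.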
Prototypical networks are used not only for one-shot and few-shot learning but also for zero-shot learning. Consider the case where you have no data points for a class, but you do have meta-information containing a high-level description of each class. In that case, we learn an embedding of each class's meta-information to form the class prototype and then perform classification with these class prototypes.
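A minimal sketch of the zero-shot case is shown below, under several assumptions that are not in the text: the meta-information is a hand-made attribute vector per class, the embedding of that meta-information is a plain linear map W (which would be learned in practice but is fixed here), and the query embedding is a made-up 2-D vector.

```python
import numpy as np

def zero_shot_prototypes(class_meta, W):
    """Embed each class's meta-information vector; one prototype row per class."""
    return np.asarray(class_meta, dtype=float) @ np.asarray(W, dtype=float).T

class_names = ["lion", "elephant", "dog"]
# Hypothetical attributes per class: [has_mane, has_trunk, barks]
class_meta = [[1, 0, 0],
              [0, 1, 0],
              [0, 0, 1]]
# Stand-in for a learned map from 3 attributes to a 2-D embedding space.
W = np.array([[1.0, 0.0, 0.5],
              [0.0, 1.0, 0.5]])

prototypes = zero_shot_prototypes(class_meta, W)   # shape (3, 2)
query_embedding = np.array([0.9, 0.1])             # toy image embedding
distances = np.linalg.norm(prototypes - query_embedding, axis=1)
predicted = class_names[int(np.argmin(distances))]
```

Once the prototypes are built from meta-information instead of support images, classification proceeds exactly as before, by picking the nearest prototype.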