Semi-prototypical networks
Now, we will see another interesting variant of prototypical networks, called the semi-prototypical network, which is designed to handle unlabeled examples. As we know, in a prototypical network, we compute the prototype of each class by taking the mean embedding of that class's support examples. We then predict the class of each query point by computing its distance to every class prototype and choosing the nearest one.
Consider the case where our dataset also contains some unlabeled data points: how can we make use of these unlabeled points when computing the class prototypes?
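To make the baseline concrete, here is a minimal NumPy sketch of the standard prototypical-network step described above. It assumes the embeddings have already been produced by the embedding network. The function names `compute_prototypes` and `predict` and the toy 2-D embeddings are illustrative only.

```python
import numpy as np


def compute_prototypes(support_emb, support_labels, num_classes):
    """Class prototype = mean embedding of that class's support examples."""
    return np.stack([support_emb[support_labels == c].mean(axis=0)
                     for c in range(num_classes)])


def predict(query_emb, prototypes):
    """Assign each query point to the class with the nearest prototype (Euclidean)."""
    # dists[i, c] = ||query_i - prototype_c||
    dists = np.linalg.norm(query_emb[:, None, :] - prototypes[None, :, :], axis=-1)
    return dists.argmin(axis=1)


# Toy 2-D embeddings for two classes (assumed output of the embedding network)
support_emb = np.array([[0.0, 0.0], [0.0, 2.0], [4.0, 4.0], [6.0, 4.0]])
support_labels = np.array([0, 0, 1, 1])
prototypes = compute_prototypes(support_emb, support_labels, num_classes=2)
print(prototypes.tolist())  # [[0.0, 1.0], [5.0, 4.0]]

query_emb = np.array([[0.5, 0.5], [5.0, 5.0]])
print(predict(query_emb, prototypes).tolist())  # [0, 1]
```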
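Let's say we have a support set, $S = \{(x_1, y_1), (x_2, y_2), \ldots, (x_N, y_N)\}$, where $x$ is the feature and $y$ is the label, and a query set, $Q = \{(x_1^*, y_1^*), (x_2^*, y_2^*), \ldots, (x_T^*, y_T^*)\}$. Along with these, we have one more set called the unlabeled set, $R$, where we have only unlabeled examples, $R = \{\tilde{x}_1, \tilde{x}_2, \ldots, \tilde{x}_M\}$.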
So, what can we do with this unlabeled set?
First, we compute the class prototypes using all of the examples in the support set. Next, we use soft k-means to assign classes to the unlabeled examples in R. Each unlabeled example receives a soft (partial) assignment to every class, based on its Euclidean distance to the class prototypes, and these soft assignments are then used to refine the prototypes.
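The soft k-means refinement step can be sketched as follows. This is a minimal NumPy version under stated assumptions: the soft assignments are a softmax over negative squared Euclidean distances, only a single refinement step is performed, and the helper names `soft_assign` and `refine_prototypes` are hypothetical.

```python
import numpy as np


def soft_assign(unlabeled_emb, prototypes):
    """Soft k-means assignment: softmax over negative squared Euclidean distances."""
    sq_dists = ((unlabeled_emb[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
    neg = -sq_dists
    weights = np.exp(neg - neg.max(axis=1, keepdims=True))  # subtract max for stability
    return weights / weights.sum(axis=1, keepdims=True)     # (M, C); each row sums to 1


def refine_prototypes(support_emb, support_labels, unlabeled_emb, num_classes):
    """Initial prototypes from the support set, then one soft k-means refinement step."""
    z_support = np.eye(num_classes)[support_labels]          # (N, C) one-hot labels
    prototypes = (z_support.T @ support_emb) / z_support.sum(axis=0)[:, None]
    z_unlabeled = soft_assign(unlabeled_emb, prototypes)     # (M, C) soft labels
    # Weighted mean over labeled (hard weights) and unlabeled (soft weights) examples
    numer = z_support.T @ support_emb + z_unlabeled.T @ unlabeled_emb
    denom = z_support.sum(axis=0) + z_unlabeled.sum(axis=0)
    return numer / denom[:, None], z_unlabeled


support_emb = np.array([[0.0, 0.0], [0.0, 2.0], [4.0, 4.0], [6.0, 4.0]])
support_labels = np.array([0, 0, 1, 1])
unlabeled_emb = np.array([[0.0, 4.0]])
refined, z_unlabeled = refine_prototypes(support_emb, support_labels, unlabeled_emb, 2)
print(refined.round(3).tolist())  # [[0.0, 2.0], [5.0, 4.0]]
```

Here the unlabeled point (0, 4) is assigned almost entirely to class 0, so the prototype of class 0 moves from (0, 1) to roughly (0, 2), while class 1 is essentially unchanged.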
However, the problem with this approach is that, because we are using soft k-means, every unlabeled example is forced to belong to one of the existing class prototypes. Let's say we have three classes in the support set, {lion, elephant, dog}; if our unlabeled set contains a data point representing a cat, then it is not meaningful to place the cat in any of the classes in the support set. So, instead of adding such a data point to an existing class, we introduce a new class for these unlabeled examples, called the distractor class.
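One way to sketch the distractor class is to add an extra cluster to the soft k-means assignment. In this hypothetical NumPy version we assume the distractor prototype is fixed at the origin of the embedding space and we omit any learned scaling. Probability mass assigned to the distractor column simply does not contribute to the real class prototypes.

```python
import numpy as np


def refine_with_distractor(support_emb, support_labels, unlabeled_emb, num_classes):
    """One soft k-means step with an extra distractor cluster (simplified, hypothetical)."""
    z_support = np.eye(num_classes)[support_labels]                   # (N, C) one-hot
    prototypes = (z_support.T @ support_emb) / z_support.sum(axis=0)[:, None]
    # Assumption: the distractor prototype sits at the origin of the embedding space
    distractor = np.zeros((1, support_emb.shape[1]))
    all_protos = np.vstack([prototypes, distractor])                  # (C + 1, D)
    sq_dists = ((unlabeled_emb[:, None, :] - all_protos[None, :, :]) ** 2).sum(axis=-1)
    neg = -sq_dists
    z_all = np.exp(neg - neg.max(axis=1, keepdims=True))
    z_all /= z_all.sum(axis=1, keepdims=True)                         # (M, C + 1)
    # Only the C real classes are refined; mass given to the distractor is discarded
    z_unlabeled = z_all[:, :num_classes]
    numer = z_support.T @ support_emb + z_unlabeled.T @ unlabeled_emb
    denom = z_support.sum(axis=0) + z_unlabeled.sum(axis=0)
    return numer / denom[:, None], z_all


support_emb = np.array([[10.0, 0.0], [10.0, 2.0], [0.0, 10.0], [2.0, 10.0]])
support_labels = np.array([0, 0, 1, 1])
# (10, 1.5) resembles class 0; (0.2, 0.1) is an out-of-class point near the origin
unlabeled_emb = np.array([[10.0, 1.5], [0.2, 0.1]])
refined, z_all = refine_with_distractor(support_emb, support_labels, unlabeled_emb, 2)
print(z_all.argmax(axis=1).tolist())  # [0, 2]
print(refined.round(3).tolist())      # [[10.0, 1.167], [1.0, 10.0]]
```

The point near the origin, (0.2, 0.1), plays the role of the cat: it is assigned to the distractor column (index 2) and leaves both class prototypes untouched, while (10, 1.5) pulls the prototype of class 0 slightly upward.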
But even with this approach, we run into another problem: the distractor class itself will have high variance. For example, suppose our unlabeled set, R, contains completely unrelated data points such as {cat, helicopter, bus, ...}. In this case, it is not a good idea to keep all of these unlabeled examples in a single distractor class, because they are unrelated to each other and the resulting class would be impure.
So, instead, we model distractors as the examples that are not within some threshold distance of any of the class prototypes. How can we compute this threshold? First, we compute the normalized distance between each unlabeled example in R and each class prototype. Next, we compute a threshold for each class prototype by feeding various statistics of these normalized distances, such as their min, max, skewness, and kurtosis, to a neural network. Based on this threshold, we either add an unlabeled example to a class prototype or ignore it.
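The thresholding idea can be sketched as a soft mask on top of soft k-means. Everything here is an assumption for illustration: distances are normalized by their mean over R, the statistics are min, max, variance, skewness, and excess kurtosis, and `threshold_mlp` is a hypothetical two-layer network with random, untrained weights that outputs a threshold `beta` and a slope `gamma` for each class. In a real model these weights would be learned together with the embedding network.

```python
import numpy as np


def sq_distances(a, b):
    """Squared Euclidean distance between every row of a and every row of b."""
    return ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)


def distance_stats(d_norm):
    """Per-class stats over the unlabeled set: min, max, variance, skewness, excess kurtosis."""
    mu = d_norm.mean(axis=0)
    var = d_norm.var(axis=0)
    std = np.sqrt(var) + 1e-8
    centered = d_norm - mu
    skew = (centered ** 3).mean(axis=0) / std ** 3
    kurt = (centered ** 4).mean(axis=0) / std ** 4 - 3.0
    return np.stack([d_norm.min(axis=0), d_norm.max(axis=0), var, skew, kurt], axis=1)  # (C, 5)


def threshold_mlp(stats, params):
    """Hypothetical 2-layer MLP mapping each class's 5 stats to (beta_c, gamma_c)."""
    W1, b1, W2, b2 = params
    hidden = np.tanh(stats @ W1 + b1)
    out = hidden @ W2 + b2                      # (C, 2)
    beta = out[:, 0]
    gamma = np.log1p(np.exp(out[:, 1]))         # softplus keeps the slope positive (assumption)
    return beta, gamma


def masked_refine(support_emb, support_labels, unlabeled_emb, num_classes, params):
    z_support = np.eye(num_classes)[support_labels]
    prototypes = (z_support.T @ support_emb) / z_support.sum(axis=0)[:, None]
    d = sq_distances(unlabeled_emb, prototypes)                   # (M, C)
    # Soft k-means assignment of unlabeled examples to the real classes
    neg = -d
    z_unlabeled = np.exp(neg - neg.max(axis=1, keepdims=True))
    z_unlabeled /= z_unlabeled.sum(axis=1, keepdims=True)
    # Normalize each class's distances by their mean over the unlabeled set
    d_norm = d / (d.mean(axis=0, keepdims=True) + 1e-8)
    beta, gamma = threshold_mlp(distance_stats(d_norm), params)
    # sigmoid(-gamma * (d_norm - beta)): near 1 below the threshold, near 0 above it
    mask = 1.0 / (1.0 + np.exp(gamma * (d_norm - beta)))
    w = z_unlabeled * mask
    numer = z_support.T @ support_emb + w.T @ unlabeled_emb
    denom = z_support.sum(axis=0) + w.sum(axis=0)
    return numer / denom[:, None], mask


support_emb = np.array([[10.0, 0.0], [10.0, 2.0], [0.0, 10.0], [2.0, 10.0]])
support_labels = np.array([0, 0, 1, 1])
unlabeled_emb = np.array([[10.0, 1.5], [1.0, 9.5], [0.2, 0.1], [-5.0, -5.0]])
rng = np.random.default_rng(0)
params = (rng.normal(size=(5, 8)), np.zeros(8), rng.normal(size=(8, 2)), np.zeros(2))
refined, mask = masked_refine(support_emb, support_labels, unlabeled_emb, 2, params)
print(refined.round(3))
print(mask.round(3))
```

Because the weights are random, the particular thresholds are meaningless. The sketch only shows the data flow: distances, then per-class statistics, then per-class thresholds, then a mask that scales how much each unlabeled example contributes to each prototype.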