Transfer Learning
To recap: we have several tasks, each with a support set and a query set. Tasks used for training have ground-truth labels for all examples in both the support and query sets, whereas tasks used for validation or testing have ground-truth labels only for the support set, not the query set. We must predict the query-set labels given the labeled support-set samples.
This idea was presented in 2019; the reference is given below. It has two stages: Training and Fine-Tuning (the test stage).
The goal is: given abundant labeled base-class data $X_b$ and a small amount of labeled novel-class data $X_n$ (classes unseen during training), the model should quickly learn to classify the novel classes from very few labeled examples.
We use $X_b$ to train a standard classification model consisting of a feature extractor $f_\theta$, which is typically a CNN, and a classifier $C(\cdot\,|\,\mathbf{W}_b)$ on top of it. Here, $\mathbf{W}_b$ is the weight matrix of the classifier. In a general setup the classifier can have many layers; here it is a single layer.
Since all of the support and query sets in the base-class data have labels, we can use them to train both the feature extractor and the classifier with cross-entropy loss. This dataset is split into training and validation sets to train the model.
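As a sketch, the training stage might look like the following. The tiny MLP standing in for the CNN backbone, the tensor shapes, and the hyperparameters are illustrative assumptions of mine, not the paper's actual setup:

```python
import torch
import torch.nn as nn

feat_dim, num_base_classes = 64, 5  # illustrative sizes

# f_theta: a tiny placeholder for the CNN backbone
feature_extractor = nn.Sequential(nn.Flatten(), nn.Linear(784, feat_dim), nn.ReLU())
# C(. | W_b): a single linear layer on top of the features
classifier = nn.Linear(feat_dim, num_base_classes)

model = nn.Sequential(feature_extractor, classifier)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# one training step on a fake base-class batch (random placeholder data)
x = torch.randn(32, 1, 28, 28)
y = torch.randint(0, num_base_classes, (32,))
logits = model(x)
loss = loss_fn(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Both the backbone and the classifier receive gradients here; nothing is frozen during this stage.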
For this stage we use the novel-class data $X_n$ (unseen during training) to fine-tune only the classifier weights. In the figure above, the feature extractor $f_\theta$ is fixed; only the classifier weights, denoted $\mathbf{W}_n$ in the figure, are learned from the few novel-class samples.
Here, training (fine-tuning) uses the support-set data, and validation uses the query-set data.
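A minimal sketch of this fine-tuning stage, assuming a frozen placeholder backbone and a 5-way 1-shot support set (both stand-ins I made up for illustration):

```python
import torch
import torch.nn as nn

feat_dim, num_novel_classes = 64, 5  # illustrative sizes

# f_theta: placeholder for the pre-trained backbone, now frozen
feature_extractor = nn.Sequential(nn.Flatten(), nn.Linear(784, feat_dim), nn.ReLU())
for p in feature_extractor.parameters():
    p.requires_grad = False

# W_n: a fresh classifier trained from scratch on the support set
novel_classifier = nn.Linear(feat_dim, num_novel_classes)
optimizer = torch.optim.SGD(novel_classifier.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# 5-way 1-shot support set (random placeholder tensors)
support_x = torch.randn(5, 1, 28, 28)
support_y = torch.arange(5)

for _ in range(10):  # a few fine-tuning steps
    with torch.no_grad():
        feats = feature_extractor(support_x)  # features stay fixed
    loss = loss_fn(novel_classifier(feats), support_y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

Only `novel_classifier` is passed to the optimizer, so the backbone weights cannot change even by accident.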
The authors of this paper propose two models that can be trained this way: Baseline and Baseline++. Both share the same feature extractor but differ in the classifier.
Here, $d$ is the output dimension of the feature extractor $f_\theta$, and $c$ is the number of output classes. The rest is self-explanatory.
Here, $\mathbf{w}_k$ is the $k$-th column vector of the weight matrix $\mathbf{W}_b = [\mathbf{w}_1, \mathbf{w}_2, \ldots, \mathbf{w}_c]$, and the weights in this column are connected to the $k$-th neuron in the classification layer, where $\mathbf{w}_k \in \mathbb{R}^d$ and $k \in \{1, \ldots, c\}$.
For the $i$-th training example $x_i$, we compute the feature $f_\theta(x_i)$ and its cosine similarity to each $\mathbf{w}_k$, yielding similarity scores $[s_{i,1}, s_{i,2}, \ldots, s_{i,c}]$, where $s_{i,k} = \frac{f_\theta(x_i)^\top \mathbf{w}_k}{\lVert f_\theta(x_i)\rVert \, \lVert \mathbf{w}_k \rVert}$. These scores are normalized with a softmax to obtain class probabilities.
The learned weight vectors can be thought of as prototypes, or representative vectors, for each class: the input is assigned to the class with the highest similarity score. For example, if the input's feature is most similar to $\mathbf{w}_2$, the input is classified as class 2.
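The cosine-similarity classifier above can be sketched as a small module like this. The `scale` factor on the logits is an assumption on my part (a temperature commonly used so softmax outputs are not overly flat), not something stated in the text:

```python
import torch
import torch.nn.functional as F

class CosineClassifier(torch.nn.Module):
    """Classifies by cosine similarity between features and per-class weight vectors."""

    def __init__(self, feat_dim, num_classes, scale=10.0):
        super().__init__()
        # one learnable "prototype" column vector w_k per class
        self.weight = torch.nn.Parameter(torch.randn(num_classes, feat_dim))
        self.scale = scale  # assumed temperature on the logits

    def forward(self, feats):
        feats = F.normalize(feats, dim=-1)      # unit-norm features
        w = F.normalize(self.weight, dim=-1)    # unit-norm prototypes
        # dot product of unit vectors == cosine similarity
        return self.scale * feats @ w.t()

clf = CosineClassifier(feat_dim=64, num_classes=3)
logits = clf(torch.randn(8, 64))  # one row of similarity scores per input
probs = logits.softmax(dim=-1)    # normalized to class probabilities
```

Because both operands are normalized, each raw similarity lies in $[-1, 1]$ before scaling, regardless of the feature magnitudes.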
Easy to implement and train
For the feature extractor, we can use networks pre-trained on large datasets such as ImageNet.
Fast to train on GPUs, as training is heavily parallelizable and can be spread across multiple GPUs (i.e., distributed training).
Because the model is trained on a large base-class dataset, the learned features can overfit to the base classes.
There is a good chance that the hyperparameters (optimizer, number of epochs, learning rate, etc.) used in training will not work during fine-tuning in a cross-domain setting. The novel-class tasks may come from a different distribution, forcing us to fine-tune the classifier weights all over again to reach the desired performance. It also breaks down when the data itself is cross-domain, e.g., a few examples from automobiles and a few from animals.
This approach requires a huge amount of data up front, which is often infeasible, e.g., when the data is collected during robot exploration.
A Closer Look at Few-Shot Classification, and its official implementation on GitHub (GitHub code).