Learning to Self-Train for Semi-Supervised Few-Shot Classification
Few-shot classification (FSC) is challenging due to the scarcity of labeled training data (e.g., only one labeled data point per class). Meta-learning has been shown to achieve promising results by learning to initialize a classification model for FSC. In this paper we propose a novel semi-supervised meta-learning method called learning to self-train (LST) that leverages unlabeled data and specifically meta-learns how to cherry-pick and label such unlabeled data to further improve performance. To this end, we train the LST model through a large number of semi-supervised few-shot tasks. On each task, we train a few-shot model to predict pseudo labels for unlabeled data, and then iterate the self-training steps on labeled and pseudo-labeled data, with each step followed by fine-tuning. We additionally learn a soft weighting network (SWN) to optimize the self-training weights of pseudo labels so that better ones contribute more to gradient descent optimization. We evaluate our LST method on two ImageNet benchmarks for semi-supervised few-shot classification and achieve large improvements over the state-of-the-art method.
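To make the pseudo-labeling step concrete, the sketch below takes class-probability vectors that a few-shot model has produced for unlabeled samples, assigns each sample its argmax label, and keeps only the z most confident samples per predicted class. This per-class top-z filter is our assumption for the "hard" part of cherry-picking; the function names and the parameter `z` are hypothetical, and the meta-learned soft weighting is not modeled here.

```python
from collections import defaultdict


def predict_pseudo_labels(probs):
    """Return (label, confidence) for each probability vector via argmax."""
    out = []
    for p in probs:
        label = max(range(len(p)), key=lambda c: p[c])
        out.append((label, p[label]))
    return out


def hard_select(probs, z):
    """Keep at most z of the most confident unlabeled samples per pseudo class.

    Hypothetical helper: a simple confidence filter, not the paper's code.
    Returns a list of (sample_index, pseudo_label) sorted by sample index.
    """
    by_class = defaultdict(list)
    for idx, (label, conf) in enumerate(predict_pseudo_labels(probs)):
        by_class[label].append((conf, idx))
    picked = []
    for label, items in by_class.items():
        items.sort(reverse=True)  # most confident first
        picked.extend((idx, label) for _, idx in items[:z])
    return sorted(picked)


# Five unlabeled samples in a 2-class task (probabilities are made up).
probs = [[0.9, 0.1], [0.6, 0.4], [0.2, 0.8], [0.45, 0.55], [0.7, 0.3]]
print(hard_select(probs, z=2))
```

With `z=2`, sample 1 is dropped because it is the least confident of the three samples pseudo-labeled as class 0.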
A novel self-training strategy that prevents the model from drifting due to label noise and enables robust recursive training.
A novel meta-learned cherry-picking method that optimizes the weights of pseudo labels particularly for fast and efficient self-training.
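The point of weighting pseudo labels is that a sample's weight scales its contribution to the loss, and therefore to the gradient. The minimal sketch below computes a weighted mean cross-entropy. The weights are simply passed in as numbers standing in for what the SWN would predict. How LST actually computes and normalizes these weights is not shown here, and the normalization by the sum of weights is an assumption of this sketch.

```python
import math


def weighted_cross_entropy(probs, labels, weights):
    """Weighted mean of per-sample cross-entropy -log p[y].

    weights: one non-negative weight per sample; a stand-in (assumption) for
    the soft weights an SWN-like network would assign to pseudo-labeled data.
    """
    total, norm = 0.0, 0.0
    for p, y, w in zip(probs, labels, weights):
        total += w * -math.log(p[y])
        norm += w
    if norm <= 0:
        raise ValueError("at least one sample needs a positive weight")
    return total / norm


probs = [[0.5, 0.5], [0.25, 0.75]]
labels = [0, 1]
print(weighted_cross_entropy(probs, labels, [1.0, 0.0]))  # second sample ignored
print(weighted_cross_entropy(probs, labels, [1.0, 1.0]))  # plain mean
```

A weight of zero removes a pseudo-labeled sample from the update entirely, while intermediate weights let uncertain samples contribute only partially.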
Extensive experiments on two ImageNet-based benchmarks, miniImageNet and tieredImageNet, on which our method achieves top performance.
Our Method: LST
Fig 1. The pipeline of the proposed LST method on a single (2-class, 3-shot) task.
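Each training and test episode is one semi-supervised N-way, K-shot task like the (2-class, 3-shot) task in Fig 1. The sketch below samples such a task from a dictionary of per-class examples. It splits each chosen class into a labeled support set, an unlabeled set, and a labeled query set. The function name and the sizes `n_unlabeled` and `n_query` are hypothetical. This sketch draws unlabeled samples only from the task's own classes, so it does not model distractor classes.

```python
import random


def sample_task(dataset, n_way, k_shot, n_unlabeled, n_query, rng=random):
    """Sample one semi-supervised N-way, K-shot task (hypothetical helper).

    dataset: dict mapping class name -> list of examples.
    Returns (support, unlabeled, query). Support and query are lists of
    (example, label) with labels re-indexed to 0..n_way-1; unlabeled holds
    bare examples whose labels are hidden from the learner.
    """
    classes = rng.sample(sorted(dataset), n_way)
    support, unlabeled, query = [], [], []
    for label, cls in enumerate(classes):
        items = rng.sample(dataset[cls], k_shot + n_unlabeled + n_query)
        support += [(x, label) for x in items[:k_shot]]
        unlabeled += items[k_shot:k_shot + n_unlabeled]
        query += [(x, label) for x in items[k_shot + n_unlabeled:]]
    rng.shuffle(unlabeled)  # unlabeled data carries no class ordering
    return support, unlabeled, query


toy = {c: [f"{c}{i}" for i in range(10)] for c in "abcde"}
support, unlabeled, query = sample_task(toy, n_way=2, k_shot=3, n_unlabeled=4,
                                        n_query=2, rng=random.Random(0))
print(len(support), len(unlabeled), len(query))
```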
Fig 2. Outer-loop and inner-loop training procedures in our LST method. The inner loop in the red box contains m steps of re-training and T − m steps of fine-tuning. In recursive training, the fine-tuned θ_T replaces the initial MTL-learned θ_T for pseudo-labeling at the next stage.
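The inner-loop schedule in Fig 2 can be illustrated with a deliberately tiny model. The sketch below uses 1-D logistic regression in place of a deep network and rests on several assumptions. Each stage pseudo-labels the unlabeled data with the latest parameters. It then runs m gradient steps on labeled plus pseudo-labeled data, followed by T − m steps on labeled data only. Later stages reuse the fine-tuned parameters. Prediction confidence stands in for the SWN weights, and the outer meta-learning loop (MTL initialization and SWN meta-updates) is omitted. All names here are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def grad_step(theta, data, lr):
    """One gradient step of weighted logistic loss; data = [(x, y, weight)]."""
    w, b = theta
    gw = gb = 0.0
    norm = sum(wt for _, _, wt in data)
    for x, y, wt in data:
        err = sigmoid(w * x + b) - y  # d(loss)/d(logit) for logistic loss
        gw += wt * err * x
        gb += wt * err
    return (w - lr * gw / norm, b - lr * gb / norm)


def pseudo_label_1d(theta, unlabeled):
    """Label each x by thresholding at 0.5; confidence doubles as its weight."""
    w, b = theta
    out = []
    for x in unlabeled:
        p = sigmoid(w * x + b)
        y = 1 if p >= 0.5 else 0
        out.append((x, y, p if y == 1 else 1.0 - p))
    return out


def self_train(theta0, labeled, unlabeled, T, m, stages, lr=0.5):
    """Recursive self-training: m re-training steps, then T - m fine-tuning."""
    theta = theta0
    labeled_w = [(x, y, 1.0) for x, y in labeled]  # labeled data has weight 1
    for _ in range(stages):
        # Pseudo-label with the latest (fine-tuned) parameters.
        mixed = labeled_w + pseudo_label_1d(theta, unlabeled)
        for step in range(T):
            batch = mixed if step < m else labeled_w
            theta = grad_step(theta, batch, lr)
    return theta


labeled = [(-2.0, 0), (2.0, 1)]
unlabeled = [-1.5, -1.0, 1.0, 1.5]
theta = self_train((0.1, 0.0), labeled, unlabeled, T=10, m=5, stages=3)
print(theta)
```

Setting `m=0` reduces each stage to plain fine-tuning on the labeled data, which makes the contribution of the pseudo-labeled steps easy to isolate in such toy runs.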
Table 1. The 5-way, 1-shot and 5-shot classification accuracy (%) on miniImageNet and tieredImageNet datasets.
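Few-shot accuracies such as those in Table 1 are commonly reported as a mean over many sampled test tasks. The sketch below shows that common reporting convention as an assumption, not as a statement of this paper's exact evaluation protocol. It computes the mean accuracy (%) and a normal-approximation 95% confidence half-width from per-task accuracies.

```python
import math
from statistics import mean, stdev


def summarize_accuracy(task_accuracies):
    """Mean accuracy (%) and 95% confidence half-width over test tasks.

    task_accuracies: per-task accuracies in [0, 1]; needs at least two tasks.
    """
    avg = mean(task_accuracies)
    half_width = 1.96 * stdev(task_accuracies) / math.sqrt(len(task_accuracies))
    return 100.0 * avg, 100.0 * half_width


print(summarize_accuracy([0.6, 0.8]))  # toy values, not results from Table 1
```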
Please cite our paper if it is helpful to your work: