{"title": "Prototypical Networks for Few-shot Learning", "book": "Advances in Neural Information Processing Systems", "page_first": 4077, "page_last": 4087, "abstract": "We propose Prototypical Networks for the problem of few-shot classification, where a classifier must generalize to new classes not seen in the training set, given only a small number of examples of each new class. Prototypical Networks learn a metric space in which classification can be performed by computing distances to prototype representations of each class. Compared to recent approaches for few-shot learning, they reflect a simpler inductive bias that is beneficial in this limited-data regime, and achieve excellent results. We provide an analysis showing that some simple design decisions can yield substantial improvements over recent approaches involving complicated architectural choices and meta-learning. We further extend Prototypical Networks to zero-shot learning and achieve state-of-the-art results on the CU-Birds dataset.", "full_text": "Prototypical Networks for Few-shot Learning\n\nJake Snell\n\nUniversity of Toronto\u2217\n\nVector Institute\n\nKevin Swersky\n\nTwitter\n\nRichard Zemel\n\nUniversity of Toronto\n\nVector Institute\n\nCanadian Institute for Advanced Research\n\nAbstract\n\nWe propose Prototypical Networks for the problem of few-shot classi\ufb01cation, where\na classi\ufb01er must generalize to new classes not seen in the training set, given only\na small number of examples of each new class. Prototypical Networks learn a\nmetric space in which classi\ufb01cation can be performed by computing distances\nto prototype representations of each class. Compared to recent approaches for\nfew-shot learning, they re\ufb02ect a simpler inductive bias that is bene\ufb01cial in this\nlimited-data regime, and achieve excellent results. 
We provide an analysis showing\nthat some simple design decisions can yield substantial improvements over recent\napproaches involving complicated architectural choices and meta-learning. We\nfurther extend Prototypical Networks to zero-shot learning and achieve state-of-\nthe-art results on the CU-Birds dataset.\n\n1\n\nIntroduction\n\nFew-shot classi\ufb01cation [22, 18, 15] is a task in which a classi\ufb01er must be adapted to accommodate\nnew classes not seen in training, given only a few examples of each of these classes. A naive approach,\nsuch as re-training the model on the new data, would severely over\ufb01t. While the problem is quite\ndif\ufb01cult, it has been demonstrated that humans have the ability to perform even one-shot classi\ufb01cation,\nwhere only a single example of each new class is given, with a high degree of accuracy [18].\nTwo recent approaches have made signi\ufb01cant progress in few-shot learning. Vinyals et al. [32]\nproposed Matching Networks, which uses an attention mechanism over a learned embedding of the\nlabeled set of examples (the support set) to predict classes for the unlabeled points (the query set).\nMatching Networks can be interpreted as a weighted nearest-neighbor classi\ufb01er applied within an\nembedding space. Notably, this model utilizes sampled mini-batches called episodes during training,\nwhere each episode is designed to mimic the few-shot task by subsampling classes as well as data\npoints. The use of episodes makes the training problem more faithful to the test environment and\nthereby improves generalization. Ravi and Larochelle [24] take the episodic training idea further\nand propose a meta-learning approach to few-shot learning. Their approach involves training an\nLSTM [11] to produce the updates to a classi\ufb01er, given an episode, such that it will generalize well to\na test-set. 
Here, rather than training a single model over multiple episodes, the LSTM meta-learner\nlearns to train a custom model for each episode.\nWe attack the problem of few-shot learning by addressing the key issue of over\ufb01tting. Since data is\nseverely limited, we work under the assumption that a classi\ufb01er should have a very simple inductive\nbias. Our approach, Prototypical Networks, is based on the idea that there exists an embedding in\nwhich points cluster around a single prototype representation for each class. In order to do this,\nwe learn a non-linear mapping of the input into an embedding space using a neural network and\ntake a class\u2019s prototype to be the mean of its support set in the embedding space. Classi\ufb01cation\nis then performed for an embedded query point by simply \ufb01nding the nearest class prototype. We\n\n\u2217Initial work done while at Twitter.\n\n31st Conference on Neural Information Processing Systems (NIPS 2017), Long Beach, CA, USA.\n\n\f(a) Few-shot\n\n(b) Zero-shot\n\nFigure 1: Prototypical Networks in the few-shot and zero-shot scenarios. Left: Few-shot prototypes\nck are computed as the mean of embedded support examples for each class. Right: Zero-shot\nprototypes ck are produced by embedding class meta-data vk. In either case, embedded query points\nare classi\ufb01ed via a softmax over distances to class prototypes: p\u03c6(y = k|x) \u221d exp(\u2212d(f\u03c6(x), ck)).\n\nfollow the same approach to tackle zero-shot learning; here each class comes with meta-data giving\na high-level description of the class rather than a small number of labeled examples. 
We therefore\nlearn an embedding of the meta-data into a shared space to serve as the prototype for each class.\nClassi\ufb01cation is performed, as in the few-shot scenario, by \ufb01nding the nearest class prototype for an\nembedded query point.\nIn this paper, we formulate Prototypical Networks for both the few-shot and zero-shot settings.\nWe draw connections to Matching Networks in the one-shot setting, and analyze the underlying\ndistance function used in the model. In particular, we relate Prototypical Networks to clustering [4]\nin order to justify the use of class means as prototypes when distances are computed with a Bregman\ndivergence, such as squared Euclidean distance. We \ufb01nd empirically that the choice of distance\nis vital, as Euclidean distance greatly outperforms the more commonly used cosine similarity. On\nseveral benchmark tasks, we achieve state-of-the-art performance. Prototypical Networks are simpler\nand more ef\ufb01cient than recent meta-learning algorithms, making them an appealing approach to\nfew-shot and zero-shot learning.\n\n2 Prototypical Networks\n\n2.1 Notation\n\nIn few-shot classi\ufb01cation we are given a small support set of N labeled examples S =\n{(x1, y1), . . . , (xN , yN )} where each xi \u2208 RD is the D-dimensional feature vector of an example\nand yi \u2208 {1, . . . , K} is the corresponding label. Sk denotes the set of examples labeled with class k.\n\n2.2 Model\nPrototypical Networks compute an M-dimensional representation ck \u2208 RM , or prototype, of each\nclass through an embedding function f\u03c6 : RD \u2192 RM with learnable parameters \u03c6. 
Each prototype\nis the mean vector of the embedded support points belonging to its class:\n\nck = (1 / |Sk|) \u2211_{(xi, yi) \u2208 Sk} f\u03c6(xi)   (1)\n\nGiven a distance function d : RM \u00d7 RM \u2192 [0, +\u221e), Prototypical Networks produce a distribution\nover classes for a query point x based on a softmax over distances to the prototypes in the embedding\nspace:\n\np\u03c6(y = k | x) = exp(\u2212d(f\u03c6(x), ck)) / \u2211_{k'} exp(\u2212d(f\u03c6(x), ck'))   (2)\n\nLearning proceeds by minimizing the negative log-probability J(\u03c6) = \u2212 log p\u03c6(y = k | x) of the\ntrue class k via SGD. Training episodes are formed by randomly selecting a subset of classes from\nthe training set, then choosing a subset of examples within each class to act as the support set and a\nsubset of the remainder to serve as query points. Pseudocode to compute the loss J(\u03c6) for a training\nepisode is provided in Algorithm 1.\n\nAlgorithm 1 Training episode loss computation for Prototypical Networks. N is the number of\nexamples in the training set, K is the number of classes in the training set, NC \u2264 K is the number\nof classes per episode, NS is the number of support examples per class, NQ is the number of query\nexamples per class. RANDOMSAMPLE(S, N) denotes a set of N elements chosen uniformly at\nrandom from set S, without replacement.\nInput: Training set D = {(x1, y1), . . . , (xN, yN)}, where each yi \u2208 {1, . . . , K}. Dk denotes the\nsubset of D containing all elements (xi, yi) such that yi = k.\nOutput: The loss J for a randomly generated training episode.\nV \u2190 RANDOMSAMPLE({1, . . . , K}, NC)   \u25b7 Select class indices for episode\nfor k in {1, . . . , NC} do\n  Sk \u2190 RANDOMSAMPLE(DVk, NS)   \u25b7 Select support examples\n  Qk \u2190 RANDOMSAMPLE(DVk \\ Sk, NQ)   \u25b7 Select query examples\n  ck \u2190 (1 / NS) \u2211_{(xi, yi) \u2208 Sk} f\u03c6(xi)   \u25b7 Compute prototype from support examples (NS, not NC: the mean runs over the NS support points of class k)\nend for\nJ \u2190 0   \u25b7 Initialize loss\nfor k in {1, . . . , NC} do\n  for (x, y) in Qk do\n    J \u2190 J + (1 / (NC NQ)) [ d(f\u03c6(x), ck) + log \u2211_{k'} exp(\u2212d(f\u03c6(x), ck')) ]   \u25b7 Update loss\n  end for\nend for\n\n2.3 Prototypical Networks as Mixture Density Estimation\n\nFor a particular class of distance functions, known as regular Bregman divergences [4], the Prototypical Networks algorithm is equivalent to performing mixture density estimation on the support set\nwith an exponential family density. A regular Bregman divergence d\u03d5 is defined as:\n\nd\u03d5(z, z') = \u03d5(z) \u2212 \u03d5(z') \u2212 (z \u2212 z')T \u2207\u03d5(z'),   (3)\n\nwhere \u03d5 is a differentiable, strictly convex function of the Legendre type. Examples of Bregman\ndivergences include squared Euclidean distance ||z \u2212 z'||2 and Mahalanobis distance.\nPrototype computation can be viewed in terms of hard clustering on the support set, with one cluster\nper class and each support point assigned to its corresponding class cluster. It has been shown [4]\nfor Bregman divergences that the cluster representative achieving minimal distance to its assigned\npoints is the cluster mean. 
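The prototype computation of Equation (1) and the softmax loss of Equation (2) (the quantity accumulated in Algorithm 1) can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' code; the embedding network f_φ is assumed to have already mapped support and query points to vectors, and the function names are our own:

```python
import numpy as np

def compute_prototypes(z_support, y_support, n_classes):
    # Eq. (1): each prototype is the mean of its class's embedded support points.
    return np.stack([z_support[y_support == k].mean(axis=0)
                     for k in range(n_classes)])

def episode_loss(z_support, y_support, z_query, y_query, n_classes):
    # Eq. (2): softmax over negative squared Euclidean distances to prototypes,
    # averaged into the episode loss J as in Algorithm 1.
    c = compute_prototypes(z_support, y_support, n_classes)
    d2 = ((z_query[:, None, :] - c[None, :, :]) ** 2).sum(axis=-1)  # (N_Q, N_C)
    logits = -d2
    logits = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_p = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_p[np.arange(len(y_query)), y_query].mean()
```

In a real implementation the gradient of this loss would be backpropagated through f_φ; here the embeddings are plain arrays for clarity.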
Thus the prototype computation in Equation (1) yields optimal cluster\nrepresentatives given the support set labels when a Bregman divergence is used.\nMoreover, any regular exponential family distribution p\u03c8(z|\u03b8) with parameters \u03b8 and cumulant\nfunction \u03c8 can be written in terms of a uniquely determined regular Bregman divergence [4]:\n\np\u03c8(z|\u03b8) = exp{zT\u03b8 \u2212 \u03c8(\u03b8) \u2212 g\u03c8(z)} = exp{\u2212d\u03d5(z, \u00b5(\u03b8)) \u2212 g\u03d5(z)}   (4)\n\nConsider now a regular exponential family mixture model with parameters \u0393 = {\u03b8k, \u03c0k}, k = 1, . . . , K:\n\np(z|\u0393) = \u2211_{k=1}^{K} \u03c0k p\u03c8(z|\u03b8k) = \u2211_{k=1}^{K} \u03c0k exp(\u2212d\u03d5(z, \u00b5(\u03b8k)) \u2212 g\u03d5(z))   (5)\n\nGiven \u0393, inference of the cluster assignment y for an unlabeled point z becomes:\n\np(y = k|z) = \u03c0k exp(\u2212d\u03d5(z, \u00b5(\u03b8k))) / \u2211_{k'} \u03c0k' exp(\u2212d\u03d5(z, \u00b5(\u03b8k')))   (6)\n\nFor an equally-weighted mixture model with one cluster per class, cluster assignment inference\n(6) is equivalent to query class prediction (2) with f\u03c6(x) = z and ck = \u00b5(\u03b8k). In this case,\nPrototypical Networks are effectively performing mixture density estimation with an exponential\nfamily distribution determined by d\u03d5. The choice of distance therefore specifies modeling assumptions\nabout the class-conditional data distribution in the embedding space.\n\n2.4 Reinterpretation as a Linear Model\n\nA simple analysis is useful in gaining insight into the nature of the learned classifier. When we use\nEuclidean distance d(z, z') = ||z \u2212 z'||2, then the model in Equation (2) is equivalent to a linear\nmodel with a particular parameterization [21]. 
To see this, expand the term in the exponent:\n\n\u2212||f\u03c6(x) \u2212 ck||2 = \u2212f\u03c6(x)T f\u03c6(x) + 2ckT f\u03c6(x) \u2212 ckT ck   (7)\n\nThe first term in Equation (7) is constant with respect to the class k, so it does not affect the softmax\nprobabilities. We can write the remaining terms as a linear model as follows:\n\n2ckT f\u03c6(x) \u2212 ckT ck = wkT f\u03c6(x) + bk, where wk = 2ck and bk = \u2212ckT ck   (8)\n\nWe focus primarily on squared Euclidean distance (corresponding to spherical Gaussian densities) in\nthis work. Our results indicate that Euclidean distance is an effective choice despite the equivalence\nto a linear model. We hypothesize this is because all of the required non-linearity can be learned\nwithin the embedding function. Indeed, this is the approach that modern neural network classification\nsystems currently use, e.g., [16, 31].\n\n2.5 Comparison to Matching Networks\n\nPrototypical Networks differ from Matching Networks in the few-shot case, but the two are equivalent\nin the one-shot scenario. Matching Networks [32] produce a weighted nearest neighbor classifier given the\nsupport set, while Prototypical Networks produce a linear classifier when squared Euclidean distance\nis used. In the case of one-shot learning, ck = f\u03c6(xk) since there is only one support point per class, and\nMatching Networks and Prototypical Networks become equivalent.\nA natural question is whether it makes sense to use multiple prototypes per class instead of just one.\nIf the number of prototypes per class is fixed and greater than 1, then this would require a partitioning\nscheme to further cluster the support points within a class. This has been proposed in Mensink\net al. [21] and Rippel et al. 
[27]; however both methods require a separate partitioning phase that is\ndecoupled from the weight updates, while our approach is simple to learn with ordinary gradient\ndescent methods.\nVinyals et al. [32] propose a number of extensions, including decoupling the embedding functions of\nthe support and query points, and using a second-level, fully-conditional embedding (FCE) that takes\ninto account specific points in each episode. These could likewise be incorporated into Prototypical\nNetworks; however, they increase the number of learnable parameters, and FCE imposes an arbitrary\nordering on the support set using a bi-directional LSTM. Instead, we show that it is possible to\nachieve the same level of performance using simple design choices, which we outline next.\n\n2.6 Design Choices\n\nDistance metric Vinyals et al. [32] and Ravi and Larochelle [24] apply Matching Networks using\ncosine distance. However, for both Prototypical Networks and Matching Networks any distance is\npermissible, and we found that using squared Euclidean distance can greatly improve results for both.\nFor Prototypical Networks, we conjecture this is primarily because cosine distance is not a Bregman\ndivergence, so the equivalence to mixture density estimation discussed in Section 2.3 does not\nhold.\n\nEpisode composition A straightforward way to construct episodes, used in Vinyals et al. [32] and\nRavi and Larochelle [24], is to choose NC classes and NS support points per class in order to match\nthe expected situation at test-time. That is, if we expect at test-time to perform 5-way classification\nand 1-shot learning, then training episodes could be composed of NC = 5, NS = 1. We have found,\nhowever, that it can be extremely beneficial to train with a higher NC, or \u201cway\u201d, than will be used\nat test-time. In our experiments, we tune the training NC on a held-out validation set. 
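The linear-model equivalence of Section 2.4 is easy to verify numerically: since \u2212||z \u2212 ck||2 differs from wkT z + bk (with wk = 2ck, bk = \u2212ckT ck) only by a term constant across classes, both parameterizations yield identical softmax probabilities. A quick sketch with random vectors (illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.normal(size=4)           # embedded query f_phi(x)
C = rng.normal(size=(3, 4))      # three class prototypes c_k

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

# Logits from negative squared Euclidean distances, as in Eq. (2)
d2 = ((z[None, :] - C) ** 2).sum(axis=-1)
p_dist = softmax(-d2)

# Equivalent linear model, Eq. (8): w_k = 2 c_k, b_k = -c_k^T c_k
W, b = 2 * C, -(C * C).sum(axis=-1)
p_lin = softmax(W @ z + b)

assert np.allclose(p_dist, p_lin)  # same class probabilities
```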
Another\nconsideration is whether to match NS, or \u201cshot\u201d, at train and test-time. For Prototypical Networks,\nwe found that it is usually best to train and test with the same \u201cshot\u201d number.\n\n4\n\n\f2.7 Zero-Shot Learning\n\nZero-shot learning differs from few-shot learning in that instead of being given a support set of\ntraining points, we are given a class meta-data vector vk for each class. These could be determined\nin advance, or they could be learned from e.g., raw text [8]. Modifying Prototypical Networks to deal\nwith the zero-shot case is straightforward: we simply de\ufb01ne ck = g\u03d1(vk) to be a separate embedding\nof the meta-data vector. An illustration of the zero-shot procedure for Prototypical Networks as\nit relates to the few-shot procedure is shown in Figure 1. Since the meta-data vector and query\npoint come from different input domains, we found it was helpful empirically to \ufb01x the prototype\nembedding g to have unit length, however we do not constrain the query embedding f.\n\n3 Experiments\n\nFor few-shot learning, we performed experiments on Omniglot [18] and the miniImageNet version\nof ILSVRC-2012 [28] with the splits proposed by Ravi and Larochelle [24]. We perform zero-shot\nexperiments on the 2011 version of the Caltech UCSD bird dataset (CUB-200 2011) [34].\n\n3.1 Omniglot Few-shot Classi\ufb01cation\n\nOmniglot [18] is a dataset of 1623 handwritten characters collected from 50 alphabets. There are 20\nexamples associated with each character, where each example is drawn by a different human subject.\nWe follow the procedure of Vinyals et al. [32] by resizing the grayscale images to 28 \u00d7 28 and\naugmenting the character classes with rotations in multiples of 90 degrees. We use 1200 characters\nplus rotations for training (4,800 classes in total) and the remaining classes, including rotations, for\ntest. Our embedding architecture mirrors that used by Vinyals et al. 
[32] and is composed of four\nconvolutional blocks. Each block comprises a 64-\ufb01lter 3 \u00d7 3 convolution, batch normalization layer\n[12], a ReLU nonlinearity and a 2 \u00d7 2 max-pooling layer. When applied to the 28 \u00d7 28 Omniglot\nimages this architecture results in a 64-dimensional output space. We use the same encoder for\nembedding both support and query points. All of our models were trained via SGD with Adam [13].\nWe used an initial learning rate of 10\u22123 and cut the learning rate in half every 2000 episodes. No\nregularization was used other than batch normalization.\nWe trained Prototypical Networks using Eu-\nclidean distance in the 1-shot and 5-shot scenar-\nios with training episodes containing 60 classes\nand 5 query points per class. We found that it\nis advantageous to match the training-shot with\nthe test-shot, and to use more classes (higher\n\u201cway\u201d) per training episode rather than fewer.\nWe compare against various baselines, including\nthe Neural Statistician [7], Meta-Learner LSTM\n[24], MAML [9], and both the \ufb01ne-tuned and\nnon-\ufb01ne-tuned versions of Matching Networks\n[32]. We computed classi\ufb01cation accuracy for\nour models averaged over 1,000 randomly gen-\nerated episodes from the test set. The results\nare shown in Table 1 and to our knowledge are\ncompetitive with state-of-the-art on this dataset.\nFigure 2 shows a sample t-SNE visualization\n[20] of the embeddings learned by Prototypical\nNetworks. We visualize a subset of test char-\nacters from the same alphabet in order to gain\nbetter insight, despite the fact that classes in\nactual test episodes are likely to come from dif-\nferent alphabets. Even though the visualized\ncharacters are minor variations of each other,\nthe network is able to cluster the hand-drawn\ncharacters closely around the class prototypes.\n\nFigure 2: A t-SNE visualization of the embeddings\nlearned by Prototypical networks on the Omniglot\ndataset. 
A subset of the Tengwar script is shown\n(an alphabet in the test set). Class prototypes are\nindicated in black. Several misclassified characters\nare highlighted in red along with arrows pointing\nto the correct prototype.\n\nTable 1: Few-shot classification accuracies on Omniglot. *Uses non-standard train/test splits.\n\nModel | Dist. | Fine Tune | 5-way 1-shot | 5-way 5-shot | 20-way 1-shot | 20-way 5-shot\nMATCHING NETWORKS [32] | Cosine | N | 98.1% | 98.9% | 93.8% | 98.5%\nMATCHING NETWORKS [32] | Cosine | Y | 97.9% | 98.7% | 93.5% | 98.7%\nNEURAL STATISTICIAN [7] | - | N | 98.1% | 99.5% | 93.2% | 98.1%\nMAML [9]* | - | N | 98.7% | 99.9% | 95.8% | 98.9%\nPROTOTYPICAL NETWORKS (OURS) | Euclid. | N | 98.8% | 99.7% | 96.0% | 98.9%\n\nTable 2: Few-shot classification accuracies on miniImageNet. All accuracy results are averaged over\n600 test episodes and are reported with 95% confidence intervals. *Results reported by [24].\n\nModel | Dist. | Fine Tune | 5-way 1-shot | 5-way 5-shot\nBASELINE NEAREST NEIGHBORS* | Cosine | N | 28.86 \u00b1 0.54% | 49.79 \u00b1 0.79%\nMATCHING NETWORKS [32]* | Cosine | N | 43.40 \u00b1 0.78% | 51.09 \u00b1 0.71%\nMATCHING NETWORKS FCE [32]* | Cosine | N | 43.56 \u00b1 0.84% | 55.31 \u00b1 0.73%\nMETA-LEARNER LSTM [24]* | - | N | 43.44 \u00b1 0.77% | 60.60 \u00b1 0.71%\nMAML [9] | - | N | 48.70 \u00b1 1.84% | 63.15 \u00b1 0.91%\nPROTOTYPICAL NETWORKS (OURS) | Euclid. | N | 49.42 \u00b1 0.78% | 68.20 \u00b1 0.66%\n\n3.2 miniImageNet Few-shot Classification\n\nThe miniImageNet dataset, originally proposed by Vinyals et al. [32], is derived from the larger\nILSVRC-12 dataset [28]. The splits used by Vinyals et al. [32] consist of 60,000 color images of size\n84 \u00d7 84 divided into 100 classes with 600 examples each. For our experiments, we use the splits\nintroduced by Ravi and Larochelle [24] in order to directly compare with state-of-the-art algorithms\nfor few-shot learning. 
Their splits use a different set of 100 classes, divided into 64 training, 16\nvalidation, and 20 test classes. We follow their procedure by training on the 64 training classes and\nusing the 16 validation classes for monitoring generalization performance only.\nWe use the same four-block embedding architecture as in our Omniglot experiments, though here\nit results in a 1,600-dimensional output space due to the increased size of the images. We also\nuse the same learning rate schedule as in our Omniglot experiments and train until validation loss\nstops improving. We train using 30-way episodes for 1-shot classification and 20-way episodes for\n5-shot classification. We match train shot to test shot and each class contains 15 query points per\nepisode. We compare to the baselines as reported by Ravi and Larochelle [24], which include a simple\nnearest neighbor approach on top of features learned by a classification network on the 64 training\nclasses. The other baselines are two non-fine-tuned variants of Matching Networks (both ordinary and\nFCE) and the Meta-Learner LSTM. We compare in the non-fine-tuned setting because the fine-tuning\nprocedure as proposed by Vinyals et al. [32] is not fully described. As can be seen in Table 2, Prototypical\nNetworks achieve state-of-the-art by a wide margin on 5-shot accuracy.\nWe conducted further analysis to determine the effect of distance metric and the number of training\nclasses per episode on the performance of Prototypical Networks and Matching Networks. To make\nthe methods comparable, we use our own implementation of Matching Networks that utilizes the\nsame embedding architecture as our Prototypical Networks. In Figure 3 we compare cosine vs.\nEuclidean distance and 5-way vs. 20-way training episodes in the 1-shot and 5-shot scenarios, with\n15 query points per class per episode. 
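The episode construction used throughout these experiments (sample NC classes, then disjoint support and query sets per class, as in Algorithm 1) might be sketched as follows; the dict-of-lists dataset structure and function name are hypothetical choices for illustration:

```python
import random

def sample_episode(dataset, n_way, n_support, n_query):
    # dataset: {class_label: [examples]} (hypothetical structure).
    # Pick n_way classes, then for each class sample disjoint support and
    # query examples, all uniformly without replacement.
    episode = []
    for k in random.sample(sorted(dataset), n_way):
        examples = random.sample(dataset[k], n_support + n_query)
        episode.append((k, examples[:n_support], examples[n_support:]))
    return episode
```

For a 20-way 5-shot miniImageNet episode with 15 query points per class, this would be called as `sample_episode(data, 20, 5, 15)`.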
We note that 20-way achieves higher accuracy than 5-way\nand conjecture that the increased difficulty of 20-way classification helps the network to generalize\nbetter, because it forces the model to make more fine-grained decisions in the embedding space. Also,\nusing Euclidean distance improves performance substantially over cosine distance. This effect is even\nmore pronounced for Prototypical Networks, in which computing the class prototype as the mean of\nembedded support points is more naturally suited to Euclidean distances since cosine distance is not\na Bregman divergence.\n\nFigure 3: Comparison showing the effect of distance metric and number of classes per training\nepisode on 5-way classification accuracy for both Matching Networks and Prototypical Networks\non miniImageNet. The x-axis indicates configuration of the training episodes (way, distance, and\nshot), and the y-axis indicates 5-way test accuracy for the corresponding shot. Error bars indicate\n95% confidence intervals as computed over 600 test episodes. Note that Matching Networks and\nPrototypical Networks are identical in the 1-shot case.\n\nTable 3: Zero-shot classification accuracies on CUB-200.\n\nModel | Image Features | 50-way 0-shot Acc.\nALE [1] | Fisher | 26.9%\nSJE [2] | AlexNet | 40.3%\nSAMPLE CLUSTERING [19] | AlexNet | 44.3%\nSJE [2] | GoogLeNet | 50.1%\nDS-SJE [25] | GoogLeNet | 50.4%\nDA-SJE [25] | GoogLeNet | 50.9%\nSYNTHESIZED CLASSIFIERS [6] | GoogLeNet | 54.7%\nPROTOTYPICAL NETWORKS (OURS) | GoogLeNet | 54.8%\nZHANG AND SALIGRAMA [36] | VGG-19 | 55.3% \u00b1 0.8\n\n3.3 CUB Zero-shot Classification\n\nIn order to assess the suitability of our approach for zero-shot learning, we also run experiments on\nthe Caltech-UCSD Birds (CUB) 200-2011 dataset [34]. The CUB dataset contains 11,788 images of\n200 bird species. We closely follow the procedure of Reed et al. [25] in preparing the data. 
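The zero-shot prototype construction of Section 2.7, as used for CUB, embeds each class's attribute vector and constrains the result to unit length. A minimal sketch, assuming a plain linear embedding matrix W (the function name and dimensions are illustrative, not the paper's code):

```python
import numpy as np

def zero_shot_prototypes(attributes, W):
    # c_k = g(v_k): linearly embed each class's attribute vector, then
    # normalize the prototype to unit length. The query embedding f is
    # left unconstrained, per the paper.
    c = attributes @ W.T
    return c / np.linalg.norm(c, axis=1, keepdims=True)
```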
We use\ntheir splits to divide the classes into 100 training, 50 validation, and 50 test. For images we use 1,024-\ndimensional features extracted by applying GoogLeNet [31] to middle, upper left, upper right, lower\nleft, and lower right crops of the original and horizontally-flipped image2. At test time we use only\nthe middle crop of the original image. For class meta-data we use the 312-dimensional continuous\nattribute vectors provided with the CUB dataset. These attributes encode various characteristics of\nthe bird species such as their color, shape, and feather patterns.\nWe learned a simple linear mapping on top of both the 1,024-dimensional image features and the\n312-dimensional attribute vectors to produce a 1,024-dimensional output space. For this dataset we\nfound it helpful to normalize the class prototypes (embedded attribute vectors) to be of unit length,\nsince the attribute vectors come from a different domain than the images. Training episodes were\nconstructed with 50 classes and 10 query images per class. The embeddings were optimized via SGD\nwith Adam at a fixed learning rate of 10\u22124 and weight decay of 10\u22125. Early stopping on validation\nloss was used to determine the optimal number of epochs for retraining on the training plus validation\nset.\nTable 3 shows that we achieve state-of-the-art results when compared to methods utilizing attributes\nas class meta-data. We compare our method to a variety of zero-shot learning methods, including other\nembedding approaches such as ALE [1], SJE [2], and DS-SJE/DA-SJE [25]. We also compare to a\nrecent clustering approach [19] which trains an SVM on a learned feature space obtained by fine-tuning\nAlexNet [16].\n\n2Features downloaded from https://github.com/reedscot/cvpr2016.\n\nThe Synthesized Classifiers approach of [6] is a manifold learning technique\nthat aligns the class meta-data space with the visual model space, and the method of Zhang and\nSaligrama [36] is a structured prediction approach trained on top of VGG-19 features [30]. Since\nZhang and Saligrama [36] is a randomized method, we include their reported error bars in Table 3.\nOur Prototypical Networks outperform Synthesized Classifiers and are within error bars of Zhang and\nSaligrama [36], while being a much simpler approach than either.\nWe also ran an additional set of zero-shot experiments with stronger class meta-data. We extracted\n1,024-dimensional meta-data vectors for each CUB-200 class using the pretrained Char CNN-RNN\nmodel of [25], then trained zero-shot Prototypical Networks using the same procedure described\nabove except we used a 512-dimensional output embedding, as chosen via validation accuracy. We\nobtained test accuracy of 58.3%, compared to the 54.0% accuracy obtained by DS-SJE [25] with\na Char CNN-RNN model. Moreover, our result exceeds the 56.8% accuracy attained by DS-SJE\nwith even stronger Word CNN-RNN class-metadata representations. Taken together, these zero-shot\nclassification results demonstrate that our approach is general enough to be applied even when the\ndata points (images) are from a different domain relative to the classes (attributes).\n\n4 Related Work\n\nThe literature on metric learning is vast [17, 5]; we summarize here the work most relevant to\nour proposed method. 
Neighborhood Components Analysis (NCA) [10] learns a Mahalanobis\ndistance to maximize K-nearest-neighbor\u2019s (KNN) leave-one-out accuracy in the transformed space.\nSalakhutdinov and Hinton [29] extend NCA by using a neural network to perform the transformation.\nLarge margin nearest neighbor (LMNN) classi\ufb01cation [33] also attempts to optimize KNN accuracy\nbut does so using a hinge loss that encourages the local neighborhood of a point to contain other\npoints with the same label. The DNet-KNN [23] is another margin-based method that improves\nupon LMNN by utilizing a neural network to perform the embedding instead of a simple linear\ntransformation. Of these, our method is most similar to the non-linear extension of NCA [29] because\nwe use a neural network to perform the embedding and we optimize a softmax based on Euclidean\ndistances in the transformed space, as opposed to a margin loss. A key distinction between our\napproach and non-linear NCA is that we form a softmax directly over classes, rather than individual\npoints, computed from distances to each class\u2019s prototype representation. This allows each class to\nhave a concise representation independent of the number of data points and obviates the need to store\nthe entire support set to make predictions.\nOur approach is also similar to the nearest class mean approach [21], where each class is represented\nby the mean of its examples. This approach was developed to rapidly incorporate new classes into\na classi\ufb01er without retraining, however it relies on a linear embedding and was designed to handle\nthe case where the novel classes come with a large number of examples. In contrast, our approach\nutilizes neural networks to non-linearly embed points and we couple this with episodic training in\norder to handle the few-shot scenario. Mensink et al. 
[21] attempt to extend their approach to also\nperform non-linear classi\ufb01cation, but they do so by allowing classes to have multiple prototypes.\nThey \ufb01nd these prototypes in a pre-processing step by using k-means on the input space and then\nperform a multi-modal variant of their linear embedding. Prototypical Networks, on the other hand,\nlearn a non-linear embedding in an end-to-end manner with no such pre-processing, producing a\nnon-linear classi\ufb01er that still only requires one prototype per class. In addition, our approach naturally\ngeneralizes to other distance functions, particularly Bregman divergences.\nThe center loss proposed by Wen et al. [35] for face recognition is similar to ours but has two main\ndifferences. First, they learn the centers for each class as parameters of the model whereas we\ncompute protoypes as a function of the labeled examples within each episode. Second, they combine\nthe center loss with a softmax loss in order to prevent representations collapsing to zero, whereas we\nconstruct a softmax loss from our prototypes which naturally prevents such collapse. Moreover, our\napproach is designed for the few-shot scenario rather than face recognition.\nA relevant few-shot learning method is the meta-learning approach proposed in Ravi and Larochelle\n[24]. The key insight here is that LSTM dynamics and gradient descent can be written in effectively\nthe same way. An LSTM can then be trained to itself train a model from a given episode, with the\nperformance goal of generalizing well on the query points. MAML [9] is another meta-learning\napproach to few-shot learning. It seeks to learn a representation that is easily \ufb01t to new data with few\n\n8\n\n\fsteps of gradient descent. 
Matching Networks and Prototypical Networks can also be seen as forms of meta-learning, in the sense that they produce simple classifiers dynamically from new training episodes; however, the core embeddings they rely on are fixed after training. The FCE extension to Matching Networks involves a secondary embedding that depends on the support set. However, in the few-shot scenario the amount of data is so small that a simple inductive bias seems to work well, without the need to learn a custom embedding for each episode.
Prototypical Networks are also related to the Neural Statistician [7] from the generative modeling literature, which extends the variational autoencoder [14, 26] to learn generative models of datasets rather than of individual points. One component of the Neural Statistician is the "statistic network", which summarizes a set of data points into a statistic vector. It does this by encoding each point within a dataset, taking a sample mean, and applying a post-processing network to obtain an approximate posterior over the statistic vector. Edwards and Storkey [7] test their model for one-shot classification on the Omniglot dataset by considering each character to be a separate dataset and making predictions based on the class whose approximate posterior over the statistic vector has minimal KL-divergence from the posterior inferred by the test point. Like the Neural Statistician, we also produce a summary statistic for each class. However, ours is a discriminative model, as befits our discriminative task of few-shot classification.
With respect to zero-shot learning, the use of embedded meta-data in Prototypical Networks resembles the method of [3] in that both predict the weights of a linear classifier. The DS-SJE and DA-SJE approach of [25] also learns deep multimodal embedding functions for images and class meta-data. Unlike ours, they learn using an empirical risk loss.
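The episodic training referred to throughout this section amounts to repeatedly subsampling a set of classes and splitting each class's examples into disjoint support and query sets, so that every training batch mimics the few-shot task. A minimal sketch follows, assuming a toy in-memory dataset; the function name and all sizes are our own, not from the paper.

```python
import numpy as np

rng = np.random.default_rng(1)

def sample_episode(data_by_class, n_way=3, n_support=2, n_query=3):
    """Sample one few-shot episode: subsample n_way classes, then split
    each class's examples into support and query sets."""
    classes = rng.choice(len(data_by_class), size=n_way, replace=False)
    support, query = [], []
    for c in classes:
        # Shuffle this class's examples, then split support/query.
        idx = rng.permutation(len(data_by_class[c]))
        support.append(data_by_class[c][idx[:n_support]])
        query.append(data_by_class[c][idx[n_support:n_support + n_query]])
    return np.stack(support), np.stack(query)

# Toy dataset: 10 classes with 8 examples each, 5-dimensional features.
data = [rng.normal(size=(8, 5)) for _ in range(10)]
S, Q = sample_episode(data)
print(S.shape, Q.shape)   # (3, 2, 5) (3, 3, 5)
```

In training, the loss would be computed on the query points of each episode given prototypes built from its support points, and a new episode would be drawn for every update.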
Neither [3] nor [25] uses episodic training, which we find both speeds up training and helps to regularize the model.

5 Conclusion

We have proposed a simple method called Prototypical Networks for few-shot learning, based on the idea that we can represent each class by the mean of its examples in a representation space learned by a neural network. We train these networks to specifically perform well in the few-shot setting by using episodic training. The approach is far simpler and more efficient than recent meta-learning approaches, and produces state-of-the-art results even without the sophisticated extensions developed for Matching Networks (although these can be applied to Prototypical Networks as well). We show how performance can be greatly improved by carefully considering the chosen distance metric, and by modifying the episodic learning procedure. We further demonstrate how to generalize Prototypical Networks to the zero-shot setting, and achieve state-of-the-art results on the CUB-200 dataset. A natural direction for future work is to utilize Bregman divergences other than squared Euclidean distance, corresponding to class-conditional distributions beyond spherical Gaussians. We conducted preliminary explorations of this, including learning a variance per dimension for each class. This did not lead to any empirical gains, suggesting that the embedding network has enough flexibility on its own without requiring additional fitted parameters per class. Overall, the simplicity and effectiveness of Prototypical Networks make it a promising approach for few-shot learning.

Acknowledgements

We would like to thank Marc Law, Sachin Ravi, Hugo Larochelle, Renjie Liao, and Oriol Vinyals for helpful discussions. This work was supported by the Samsung GRP project and the Canadian Institute for Advanced Research.

References

[1] Zeynep Akata, Florent Perronnin, Zaid Harchaoui, and Cordelia Schmid.
Label-embedding for attribute-based classification. In IEEE Computer Vision and Pattern Recognition, pages 819–826, 2013.

[2] Zeynep Akata, Scott Reed, Daniel Walter, Honglak Lee, and Bernt Schiele. Evaluation of output embeddings for fine-grained image classification. In IEEE Computer Vision and Pattern Recognition, 2015.

[3] Jimmy Ba, Kevin Swersky, Sanja Fidler, and Ruslan Salakhutdinov. Predicting deep zero-shot convolutional neural networks using textual descriptions. In International Conference on Computer Vision, pages 4247–4255, 2015.

[4] Arindam Banerjee, Srujana Merugu, Inderjit S. Dhillon, and Joydeep Ghosh. Clustering with Bregman divergences. Journal of Machine Learning Research, 6(Oct):1705–1749, 2005.

[5] Aurélien Bellet, Amaury Habrard, and Marc Sebban. A survey on metric learning for feature vectors and structured data. arXiv preprint arXiv:1306.6709, 2013.

[6] Soravit Changpinyo, Wei-Lun Chao, Boqing Gong, and Fei Sha. Synthesized classifiers for zero-shot learning. In IEEE Computer Vision and Pattern Recognition, pages 5327–5336, 2016.

[7] Harrison Edwards and Amos Storkey. Towards a neural statistician. International Conference on Learning Representations, 2017.

[8] Mohamed Elhoseiny, Babak Saleh, and Ahmed Elgammal. Write a classifier: Zero-shot learning using purely textual descriptions. In International Conference on Computer Vision, pages 2584–2591, 2013.

[9] Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. International Conference on Machine Learning, 2017.

[10] Jacob Goldberger, Geoffrey E. Hinton, Sam T. Roweis, and Ruslan Salakhutdinov. Neighbourhood components analysis. In Advances in Neural Information Processing Systems, pages 513–520, 2004.

[11] Sepp Hochreiter and Jürgen Schmidhuber. Long short-term memory.
Neural Computation, 9(8):1735–1780, 1997.

[12] Sergey Ioffe and Christian Szegedy. Batch normalization: Accelerating deep network training by reducing internal covariate shift. arXiv preprint arXiv:1502.03167, 2015.

[13] Diederik Kingma and Jimmy Ba. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980, 2014.

[14] Diederik P. Kingma and Max Welling. Auto-encoding variational Bayes. arXiv preprint arXiv:1312.6114, 2013.

[15] Gregory Koch. Siamese neural networks for one-shot image recognition. Master's thesis, University of Toronto, 2015.

[16] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E. Hinton. ImageNet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems, pages 1097–1105, 2012.

[17] Brian Kulis. Metric learning: A survey. Foundations and Trends in Machine Learning, 5(4):287–364, 2012.

[18] Brenden M. Lake, Ruslan Salakhutdinov, Jason Gross, and Joshua B. Tenenbaum. One shot learning of simple visual concepts. In CogSci, 2011.

[19] Renjie Liao, Alexander Schwing, Richard Zemel, and Raquel Urtasun. Learning deep parsimonious representations. Advances in Neural Information Processing Systems, 2016.

[20] Laurens van der Maaten and Geoffrey Hinton. Visualizing data using t-SNE. Journal of Machine Learning Research, 9(Nov):2579–2605, 2008.

[21] Thomas Mensink, Jakob Verbeek, Florent Perronnin, and Gabriela Csurka. Distance-based image classification: Generalizing to new classes at near-zero cost. IEEE Transactions on Pattern Analysis and Machine Intelligence, 35(11):2624–2637, 2013.

[22] Erik G. Miller, Nicholas E. Matsakis, and Paul A. Viola. Learning from one example through shared densities on transforms. In IEEE Computer Vision and Pattern Recognition, volume 1, pages 464–471, 2000.

[23] Renqiang Min, David A. Stanley, Zineng Yuan, Anthony Bonner, and Zhaolei Zhang.
A deep non-linear feature mapping for large-margin KNN classification. In IEEE International Conference on Data Mining, pages 357–366, 2009.

[24] Sachin Ravi and Hugo Larochelle. Optimization as a model for few-shot learning. International Conference on Learning Representations, 2017.

[25] Scott Reed, Zeynep Akata, Bernt Schiele, and Honglak Lee. Learning deep representations of fine-grained visual descriptions. In IEEE Computer Vision and Pattern Recognition, 2016.

[26] Danilo Jimenez Rezende, Shakir Mohamed, and Daan Wierstra. Stochastic backpropagation and approximate inference in deep generative models. arXiv preprint arXiv:1401.4082, 2014.

[27] Oren Rippel, Manohar Paluri, Piotr Dollar, and Lubomir Bourdev. Metric learning with adaptive density discrimination. International Conference on Learning Representations, 2016.

[28] Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, Alexander C. Berg, and Li Fei-Fei. ImageNet large scale visual recognition challenge. International Journal of Computer Vision, 115(3):211–252, 2015.

[29] Ruslan Salakhutdinov and Geoffrey E. Hinton. Learning a nonlinear embedding by preserving class neighbourhood structure. In AISTATS, pages 412–419, 2007.

[30] Karen Simonyan and Andrew Zisserman. Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556, 2014.

[31] Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, and Andrew Rabinovich. Going deeper with convolutions. In IEEE Computer Vision and Pattern Recognition, pages 1–9, 2015.

[32] Oriol Vinyals, Charles Blundell, Tim Lillicrap, Daan Wierstra, et al. Matching networks for one shot learning.
In Advances in Neural Information Processing Systems, pages 3630–3638, 2016.

[33] Kilian Q. Weinberger, John Blitzer, and Lawrence K. Saul. Distance metric learning for large margin nearest neighbor classification. In Advances in Neural Information Processing Systems, pages 1473–1480, 2005.

[34] P. Welinder, S. Branson, T. Mita, C. Wah, F. Schroff, S. Belongie, and P. Perona. Caltech-UCSD Birds 200. Technical Report CNS-TR-2010-001, California Institute of Technology, 2010.

[35] Yandong Wen, Kaipeng Zhang, Zhifeng Li, and Yu Qiao. A discriminative feature learning approach for deep face recognition. In European Conference on Computer Vision, pages 499–515. Springer, 2016.

[36] Ziming Zhang and Venkatesh Saligrama. Zero-shot recognition via structured prediction. In European Conference on Computer Vision, pages 533–548. Springer, 2016.