
Knowledge Distillation: Principles & Algorithms [+Applications]


Jul 21, 2022

Knowledge distillation in machine learning refers to transferring knowledge from a teacher to a student model. Learn about techniques for knowledge distillation.

Rohit Kundu


The past few years have seen dramatic improvements in visual recognition systems driven by deeper and larger convolutional network architectures. However, the large computational complexity of these architectures has limited their use in many downstream applications.

Ensemble Learning methods, for example, are popular for improving a model’s performance, but they multiply the computational cost several times over. As a result, much recent research has focused on achieving the same or similar accuracy with smaller models, and one such approach is distilling knowledge from larger networks into smaller ones.

Knowledge Distillation is a general-purpose technique that, at first glance, is widely applicable and complements other ways of compressing neural networks. The key idea is to use the soft probabilities of a larger “teacher network” (obtained by applying a softmax to its ‘logits’, the raw pre-softmax scores) to supervise a smaller “student network,” in addition to the available class labels. These soft probabilities reveal more information than the class labels alone and can purportedly help the student network learn better.

Here’s what we’ll cover:

  • What is Knowledge Distillation?

  • How does Knowledge Distillation work?

  • Algorithms and principles

  • Practical applications

  • Benefits and limitations


What is Knowledge Distillation?

Knowledge Distillation is a process of condensing knowledge from a complex model into a simpler one. It grew out of research on model compression in Machine Learning, and its early applications focused on creating smaller, more efficient models that could be deployed on devices with limited resources.

High-level view of Knowledge Distillation. Image by the author.

The key idea in Knowledge Distillation is to train a smaller, less complex model to mimic the way a large, complex model generalizes. If the complex model generalizes well because, for example, it averages the predictions of an ensemble of different models, a small model trained (via Knowledge Distillation) to generalize in the same way will typically do much better on test data than the same small model trained in the standard way, on hard labels alone, using the same training set.

Bucila et al. laid the foundation for compressing the knowledge of a complex framework (such as an ensemble of models) into a single neural network in this paper in 2006. The authors achieved model compression by training a smaller neural network on a large amount of pseudo-data labeled by a complex ensemble model. The training objective was to make the compressed model’s outputs match those of the ensemble on this pseudo-data.

Distilling knowledge means transferring it from the teacher network to the student network through a loss function whose optimization target is to match the student’s class-wise probability distribution to the one output by the teacher. Hinton et al. generalized this idea of model compression and formulated it as distillation in this 2015 paper. We will dive deeper into the algorithmic background in the sections below.

Why do we need Knowledge Distillation?

The best-performing Supervised Learning models are often ensembles of several large base models. However, the storage such a framework requires and the time needed to run it at inference time make it impractical for applications where the test dataset is very large or where storage and computational resources are limited.

Knowledge Distillation is a form of model compression that allows a relatively simple model to perform a task almost as accurately as a very complex model. A pre-trained “teacher” model transfers its knowledge to a “student” model. The complex model has a higher knowledge capacity, but that capacity may not be fully used for any single task. Even if the complex model uses only a small part of its knowledge for a given task, running it is still computationally expensive. Distilling that knowledge into a smaller student model tailored to the specific task is more efficient.

For example, using a model trained on ImageNet (over 14 million images in total; its widely used ILSVRC subset alone has 1,000 classes) as nothing more than a cat-vs-dog binary classifier wastes computational resources. It makes more sense to use that model as a teacher and distill its knowledge into a simpler model with only a few layers, tailored to cat-vs-dog classification.

Knowledge Distillation has allowed simple models to run on smaller hardware devices like mobile phones without much loss of performance compared to larger models.

How does Knowledge Distillation work?

Let us briefly discuss the main ideas behind Knowledge Distillation. We will look at specific algorithms in more depth in a later section.

One way a smaller model (student) can be employed to mimic the generalization ability of a complex model (teacher) is to use the class prediction probabilities generated by the teacher network as “soft targets” for training the student model. Generally, the same training set is used for this transfer of knowledge, but a separate “transfer set” of data can also be used to achieve the same.
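To make “soft targets” concrete, here is a minimal Python sketch. It assumes the temperature-scaled softmax used by Hinton et al., where dividing the logits by a temperature T > 1 spreads probability mass onto the non-target classes; the logits and class names are made up for illustration.

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """Turn raw scores (logits) into class probabilities, softened by temperature T."""
    scaled = [z / T for z in logits]
    m = max(scaled)  # subtract the max before exponentiating, for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical teacher logits for the classes [cat, dog, car].
teacher_logits = [6.0, 4.0, -2.0]
hard_target = [1.0, 0.0, 0.0]  # the one-hot label only says "cat"

soft_T1 = softmax_with_temperature(teacher_logits, T=1.0)
soft_T4 = softmax_with_temperature(teacher_logits, T=4.0)

print([round(p, 3) for p in soft_T1])  # [0.881, 0.119, 0.0]
print([round(p, 3) for p in soft_T4])  # [0.574, 0.348, 0.078]
```

At T = 4 the teacher still prefers “cat”, but it also reveals that “dog” is far more plausible than “car”, which is exactly the information a one-hot label throws away.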

When the teacher model is a large ensemble of simpler models, we can use the arithmetic or geometric mean of their individual predictive distributions as the soft targets. When the soft targets have high entropy, they provide much more information per training case than hard targets and much less variance in the gradient between training cases, so the student model can often be trained on much less data than the original teacher model and using a much higher learning rate.
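The two ways of combining ensemble members mentioned above can be sketched as follows. This is a minimal illustration with hypothetical member predictions; the geometric mean is renormalized, since averaging log-probabilities does not by itself produce a distribution that sums to 1.

```python
import math


def arithmetic_mean(dists):
    """Average the ensemble members' class probabilities directly."""
    n = len(dists)
    return [sum(d[k] for d in dists) / n for k in range(len(dists[0]))]


def geometric_mean(dists):
    """Average in log space, then renormalize so the soft targets sum to 1."""
    n = len(dists)
    unnormalized = [
        math.exp(sum(math.log(d[k]) for d in dists) / n) for k in range(len(dists[0]))
    ]
    total = sum(unnormalized)
    return [u / total for u in unnormalized]


# Hypothetical predictive distributions of three ensemble members over three classes.
ensemble = [
    [0.7, 0.2, 0.1],
    [0.6, 0.3, 0.1],
    [0.5, 0.1, 0.4],
]

soft_targets_arith = arithmetic_mean(ensemble)  # ≈ [0.6, 0.2, 0.2]
soft_targets_geo = geometric_mean(ensemble)
print([round(p, 3) for p in soft_targets_geo])
```

The geometric mean gives the last class less mass than the arithmetic mean does, because two of the three members consider it unlikely.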

A high-level overview of the Knowledge Distillation process is depicted above. The pre-trained teacher model outputs class prediction probabilities, and the student model, fed the same data, outputs a class probability distribution of its own. A “distillation loss” (discussed later in this article) pushes the student’s probability distribution towards the teacher’s. In addition, as in ordinary deep network training, the student’s predictions are compared with the hard labels (the ground-truth classes of the samples) to compute a standard cross-entropy loss. The student model is trained on a weighted combination of these two losses.
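Below is a minimal sketch of this two-part objective in plain Python, for a single example. It follows the common formulation from Hinton et al.: a KL-divergence term between temperature-softened teacher and student distributions, scaled by T², plus ordinary cross-entropy on the true label. The weighting `alpha`, the temperature `T`, and all logits are hypothetical values chosen for illustration.

```python
import math


def softmax(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions; terms with p_i = 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, true_class, T=4.0, alpha=0.7):
    # Distillation term: push the student's softened distribution toward the teacher's.
    # Multiplying by T^2 keeps its gradient magnitude comparable to the hard-label term.
    soft_loss = T * T * kl_divergence(softmax(teacher_logits, T), softmax(student_logits, T))
    # Hard-label term: ordinary cross-entropy of the student (at T = 1) against the ground truth.
    hard_loss = -math.log(softmax(student_logits)[true_class])
    return alpha * soft_loss + (1.0 - alpha) * hard_loss


teacher = [6.0, 4.0, -2.0]        # hypothetical teacher logits
student_close = [5.0, 3.5, -1.0]  # a student that already resembles the teacher
student_far = [-1.0, 4.0, 2.0]    # a student that prefers the wrong class

loss_close = distillation_loss(student_close, teacher, true_class=0)
loss_far = distillation_loss(student_far, teacher, true_class=0)
print(loss_close, loss_far)
```

When the student’s logits equal the teacher’s, the distillation term vanishes and only the weighted cross-entropy remains.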

Algorithms and Principles behind Knowledge Distillation

In Knowledge Distillation, knowledge types, distillation strategies, and teacher-student architectures all play a crucial role in student network learning. Let us discuss these next.

Knowledge

Depending on where in the deep teacher model the information is taken from, Knowledge falls into three types, illustrated below.

Response-based

In Response-based Knowledge systems, the information is taken from the output layer of the teacher model: the student model is expected to mimic the teacher’s final predictions, that is, its logits or the class probabilities derived from them. This is the most widely used type of Knowledge.

The distillation loss in such cases measures a divergence between the class probability distributions derived from the teacher’s and student’s logits. Kullback-Leibler Divergence is the most popular choice in response-based Knowledge Distillation methods.
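Why KL divergence works well as a response-based distillation loss becomes clearer from its gradient. With softened student probabilities q and teacher probabilities p at temperature T, the gradient with respect to student logit z_i is (q_i − p_i)/T, the result derived by Hinton et al. The sketch below, using hypothetical logits, checks that closed form against finite differences.

```python
import math


def softmax(logits, T):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def response_kd_loss(student_logits, teacher_logits, T):
    """KL(teacher || student) between the temperature-softened class distributions."""
    return kl_divergence(softmax(teacher_logits, T), softmax(student_logits, T))


def analytic_gradient(student_logits, teacher_logits, T):
    # dKL/dz_i = (q_i - p_i) / T, with q the student's and p the teacher's softened probabilities.
    p = softmax(teacher_logits, T)
    q = softmax(student_logits, T)
    return [(qi - pi) / T for qi, pi in zip(q, p)]


def numerical_gradient(student_logits, teacher_logits, T, eps=1e-6):
    """Central finite differences, used only to check the closed form above."""
    grads = []
    for i in range(len(student_logits)):
        up, down = list(student_logits), list(student_logits)
        up[i] += eps
        down[i] -= eps
        diff = response_kd_loss(up, teacher_logits, T) - response_kd_loss(down, teacher_logits, T)
        grads.append(diff / (2 * eps))
    return grads


student = [1.0, 2.0, 0.5]  # hypothetical logits
teacher = [3.0, 1.0, -1.0]
T = 2.0
g_exact = analytic_gradient(student, teacher, T)
g_approx = numerical_gradient(student, teacher, T)
print(g_exact)
print(g_approx)
```

Because the gradient is simply the gap between the two distributions (scaled by 1/T), the student is pushed hardest on the classes where it disagrees most with the teacher. The shrinking gradient at high T is also why the soft term is commonly rescaled by T².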

For example, this paper uses Response-based Knowledge for Knowledge Distillation in multi-class object detection, where the teacher model’s response contains the logits together with the bounding-box offsets. The learning framework adopted is shown below.

Source: Paper

Feature-based

Each layer in a deep network learns different levels of feature representation with increasing abstraction. Thus, the feature maps from the intermediate layers of a network can be used as the knowledge from the teacher to train the student model. Such Knowledge Distillation methods are said to utilize feature-based Knowledge.
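One common way to turn feature-based Knowledge into a loss, popularized by FitNets (a different method from the paper cited below), is a “hint” loss: a small learned regressor maps the student’s intermediate feature into the teacher’s feature space, and the squared L2 distance between the two is minimized. The sketch below uses flattened feature vectors with hypothetical values.

```python
def project(W, x):
    """Linear regressor mapping a student feature vector into the teacher's feature space."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]


def hint_loss(teacher_feat, student_feat, W):
    """Half the squared L2 distance between the teacher feature and the projected student feature."""
    return 0.5 * sum((t - s) ** 2 for t, s in zip(teacher_feat, project(W, student_feat)))


def regressor_step(teacher_feat, student_feat, W, lr=0.1):
    """One gradient-descent step on W (dLoss/dW_ij = -residual_i * x_j)."""
    residual = [t - s for t, s in zip(teacher_feat, project(W, student_feat))]
    return [
        [w_ij + lr * r_i * x_j for w_ij, x_j in zip(row, student_feat)]
        for row, r_i in zip(W, residual)
    ]


# Hypothetical intermediate features: the teacher layer is 4-D, the student layer only 2-D.
teacher_feat = [1.0, -0.5, 0.3, 2.0]
student_feat = [0.8, -1.2]
W = [[0.1, 0.0], [0.0, 0.1], [0.2, -0.1], [0.5, 0.3]]

before = hint_loss(teacher_feat, student_feat, W)
W = regressor_step(teacher_feat, student_feat, W)
after = hint_loss(teacher_feat, student_feat, W)
print(before, after)
```

In a full training loop, the gradient of this loss would also flow into the student’s own layers; here only the regressor is updated to keep the example small.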

For example, this paper utilizes the feature maps from the penultimate layer of the teacher model to address the image classification problem. The authors argue that since there are visual similarities between images belonging to different classes, a one-hot label that assumes classes are independent cannot always accurately describe an image's real distribution over classes. The authors used a temperature factor along with an L2-normalized feature map of the penultimate teacher network layer for noise correction.

Source: Paper

Relation-based

Both response-based and feature-based Knowledge use the outputs of specific layers in the teacher model. Relation-based Knowledge further explores the relationships between different layers or data samples.

For example, this paper proposed the flow of solution procedure (FSP) matrix, defined as the Gram matrix between two layers. The FSP matrix summarizes the relations between pairs of feature maps and is calculated from the inner products between features of the two layers. Building on the same idea of using correlations between feature maps as the distilled knowledge, a related method, knowledge distillation via singular value decomposition, extracts the key information in the feature maps. The training procedure proposed by the FSP authors is shown below.

Source: Paper
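The FSP matrix can be sketched in a few lines of Python. Assuming each layer’s output is flattened to h·w spatial positions, entry G[i][j] averages the product of channel i of the first layer and channel j of the second layer over all positions; the student is then trained to minimize the squared L2 distance between its FSP matrices and the teacher’s. The feature values are hypothetical.

```python
def fsp_matrix(F1, F2):
    """FSP matrix between two layers of the same network.

    F1[s][i] is channel i of the first layer at spatial position s, and F2[s][j] is channel j
    of the second layer at the same position; both layers must share the h*w spatial grid.
    G[i][j] = sum_s F1[s][i] * F2[s][j] / (h * w).
    """
    positions = len(F1)
    m, n = len(F1[0]), len(F2[0])
    return [
        [sum(F1[s][i] * F2[s][j] for s in range(positions)) / positions for j in range(n)]
        for i in range(m)
    ]


def fsp_loss(G_teacher, G_student):
    """Squared L2 (Frobenius) distance between the teacher's and student's FSP matrices."""
    return sum(
        (gt - gs) ** 2
        for row_t, row_s in zip(G_teacher, G_student)
        for gt, gs in zip(row_t, row_s)
    )


# Hypothetical 2x2 feature maps flattened to 4 positions: 2 channels in, 3 channels out.
F1 = [[1, 0], [0, 1], [1, 1], [2, 0]]
F2 = [[1, 2, 0], [0, 1, 1], [1, 0, 1], [0, 0, 2]]
G = fsp_matrix(F1, F2)
print(G)  # [[0.5, 0.5, 1.25], [0.25, 0.25, 0.5]]
```

For the loss to be defined, teacher and student must produce FSP matrices of the same shape for each pair of layers being compared.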

Distillation Schemes

Distillation schemes can be categorized by whether the teacher model is updated simultaneously with the student model. Let us discuss these categories next.

Offline Distillation

In Offline Distillation, the teacher network is pre-trained and then frozen. While the student network gets trained, the teacher model is not updated. Most previous Knowledge Distillation methods (including Hinton et al.’s base paper) work offline. Most research focuses on improving the knowledge transfer mechanism in this Distillation scheme, and less attention is given to the teacher network architecture.
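The defining property of Offline Distillation, a frozen teacher that is only queried, can be shown with a toy training loop. In this sketch, both teacher and student are hypothetical linear softmax classifiers on 2-D inputs, and the student is trained with plain SGD on the KL divergence to the teacher’s temperature-softened outputs; real setups would use deep networks and usually add a hard-label loss.

```python
import math
import random


def softmax(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def linear_logits(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def kl_divergence(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Hypothetical "pre-trained" teacher: fixed weights for 3 classes x 2 input features.
TEACHER_W = [[2.0, -1.0], [-1.0, 2.0], [-1.0, -1.0]]


def train_student_offline(data, T=2.0, lr=0.5, epochs=50):
    student_W = [[0.0, 0.0] for _ in TEACHER_W]
    history = []
    for _ in range(epochs):
        epoch_loss = 0.0
        for x in data:
            p = softmax(linear_logits(TEACHER_W, x), T)  # teacher is queried but never updated
            q = softmax(linear_logits(student_W, x), T)
            epoch_loss += kl_divergence(p, q)
            for k, row in enumerate(student_W):
                g = (q[k] - p[k]) / T  # gradient of KL(p || q) w.r.t. student logit k
                for j in range(len(row)):
                    row[j] -= lr * g * x[j]  # chain rule: d(logit_k)/d(W[k][j]) = x[j]
        history.append(epoch_loss / len(data))
    return student_W, history


random.seed(0)
data = [[random.uniform(-1, 1), random.uniform(-1, 1)] for _ in range(64)]
student_W, history = train_student_offline(data)
print(round(history[0], 4), round(history[-1], 4))
```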

One such Offline Distillation method was proposed in this paper, which developed a probabilistic method for knowledge transfer (using feature-based Knowledge). Instead of matching the actual feature representations, as earlier methods did, it matches the probability distribution of the data in the feature space. This allowed cross-modal knowledge transfer, as well as transferring knowledge from handcrafted feature extractors into neural networks.

Source: Paper
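The idea of matching distributions in feature space, rather than the features themselves, can be illustrated with a simplified sketch. Under our own assumptions (a cosine kernel shifted to [0, 1], row-normalized into conditional probabilities, and a KL divergence between teacher and student), it compares how each sample’s neighbors are distributed in each model’s feature space. This is not necessarily the authors’ exact formulation, and all features are hypothetical.

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def conditional_probabilities(features):
    """P[i][j]: probability of picking sample j as a neighbor of sample i (self-pairs excluded).

    Affinities use a cosine kernel shifted into [0, 1]; each row is then normalized.
    """
    n = len(features)
    affinities = [
        [0.5 * (cosine(features[i], features[j]) + 1.0) if i != j else 0.0 for j in range(n)]
        for i in range(n)
    ]
    return [[a / sum(row) for a in row] for row in affinities]


def probabilistic_transfer_loss(teacher_feats, student_feats, eps=1e-8):
    """Sum of row-wise KL(P_teacher || P_student); feature dimensions may differ."""
    P = conditional_probabilities(teacher_feats)
    Q = conditional_probabilities(student_feats)
    return sum(
        p * math.log((p + eps) / (q + eps))
        for row_p, row_q in zip(P, Q)
        for p, q in zip(row_p, row_q)
    )


# Hypothetical batch of 4 samples: 3-D teacher features, 2-D student features.
teacher = [[1.0, 0.2, 0.0], [0.9, 0.1, 0.1], [0.0, 1.0, 0.5], [0.1, 0.8, 0.7]]
student_aligned = [[1.0, 0.1], [0.95, 0.2], [0.1, 1.0], [0.2, 0.9]]    # same neighborhoods
student_scrambled = [[1.0, 0.1], [0.1, 1.0], [0.95, 0.2], [0.2, 0.9]]  # neighborhoods mixed up

print(probabilistic_transfer_loss(teacher, student_aligned))
print(probabilistic_transfer_loss(teacher, student_scrambled))
```

Because only within-batch similarities are compared, teacher and student features may have different dimensions, which is what makes this kind of transfer possible across modalities or from handcrafted feature extractors.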

Online Distillation

In many cases, the large pre-trained teacher model that Offline Distillation methods assume may not be available. In such scenarios, the teacher and student networks are trained simultaneously, which is referred to as Online Distillation.

For example, this recent paper proposed an Online Mutual Knowledge Distillation method that aims to fuse the features of several sub-networks. The ensembled sub-networks and the fusion module learn by mutually teaching each other via Knowledge Distillation (using response-based Knowledge). This process is depicted below.

Source: Paper 
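The core of Online Distillation, peers that learn from each other while being trained at the same time, can be sketched with a simpler, swapped-in scheme in the style of Deep Mutual Learning rather than the fusion-based method of the cited paper. Each of two hypothetical peers minimizes cross-entropy on the label plus the KL divergence from the other peer’s prediction, and both are updated in the same step. To keep it self-contained, each peer is reduced to a free logit vector for a single example.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def mutual_learning_step(z1, z2, label, lr=0.5):
    """Update two peers at once; each peer's loss is CE(label) + KL(other peer || itself)."""
    p1, p2 = softmax(z1), softmax(z2)
    y = [1.0 if k == label else 0.0 for k in range(len(z1))]
    # Gradient w.r.t. a peer's logits: (p_self - y) from cross-entropy plus (p_self - p_other)
    # from the KL term, with the other peer's prediction treated as a fixed target.
    g1 = [(a - t) + (a - b) for a, b, t in zip(p1, p2, y)]
    g2 = [(b - t) + (b - a) for a, b, t in zip(p1, p2, y)]
    return [z - lr * g for z, g in zip(z1, g1)], [z - lr * g for z, g in zip(z2, g2)]


# Hypothetical logits of two untrained peers for one example whose true class is 0.
z1, z2 = [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]
initial_gap = kl_divergence(softmax(z1), softmax(z2))
for _ in range(200):
    z1, z2 = mutual_learning_step(z1, z2, label=0)
p1, p2 = softmax(z1), softmax(z2)
print(round(p1[0], 3), round(p2[0], 3), kl_divergence(p1, p2) < initial_gap)
```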

Self Distillation

Conventional Knowledge Distillation has two problems:

  1. the choice of teacher model has a significant impact on the student model’s accuracy, and the teacher with the highest accuracy is not necessarily the best teacher for distillation;

  2. student models cannot consistently match the accuracy of their teachers, which may lead to unacceptable accuracy degradation at inference time.

Self Distillation addresses these problems by having the same network act as both teacher and student. It first attaches several attention-based shallow classifiers after intermediate layers of the network at different depths. During training, the deeper classifiers act as teachers and guide the training of the shallower ones (the students) through a divergence-based loss on the outputs and an L2 loss on the feature maps. At inference time, all of the additional shallow classifiers are dropped.
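The Self Distillation objective just described can be sketched as a per-example loss over several classifier heads. This minimal Python version assumes the deepest head is the teacher, uses KL divergence for the output term and a squared L2 distance for the feature term, and uses hypothetical values for the weights `alpha` and `lam` and the temperature `T`. It also assumes feature vectors of matching size; in practice an extra adaptation layer may be needed.

```python
import math


def softmax(logits, T=1.0):
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def cross_entropy(logits, label):
    return -math.log(softmax(logits)[label])


def self_distillation_loss(head_logits, head_features, label, T=3.0, alpha=0.3, lam=0.03):
    """Loss for one example with several classifier heads attached to the same network.

    head_logits[-1] and head_features[-1] belong to the deepest head, which acts as the
    teacher; every shallower head is a student.
    """
    deep_logits, deep_feat = head_logits[-1], head_features[-1]
    p_deep = softmax(deep_logits, T)
    total = cross_entropy(deep_logits, label)  # the deepest head learns from the label only
    for logits, feat in zip(head_logits[:-1], head_features[:-1]):
        ce = cross_entropy(logits, label)
        kd = kl_divergence(p_deep, softmax(logits, T))             # output-level distillation
        hint = sum((a - b) ** 2 for a, b in zip(feat, deep_feat))  # L2 on the features
        total += (1 - alpha) * ce + alpha * kd + lam * hint
    return total


# Hypothetical outputs of three heads (shallow to deep) for one example of class 0.
head_logits = [[0.5, 0.2, 0.1], [1.5, 0.3, -0.2], [3.0, 0.0, -1.0]]
head_features = [[0.2, 0.1], [0.5, 0.4], [0.9, 0.8]]
print(self_distillation_loss(head_logits, head_features, label=0))
```

In an autograd framework, the deepest head’s outputs and features would typically be detached in the distillation terms, so that the shallow heads learn from the deep one rather than the other way round.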

One such method is proposed in this paper, which addresses the Weakly-Supervised Object Detection (WSOD) problem. The authors developed a Comprehensive Attention Self-Distillation (CASD) approach that conducts Self Distillation on the WSOD network itself to enforce consistent spatial supervision on objects: the comprehensive attention is approximated simultaneously across multiple transformations of the same image and multiple layers of the network. Self Distillation thus provides instance-balanced and spatially consistent supervision, resulting in robust bounding box localization. The model framework is shown below.

Source: Paper

Distillation Algorithms

Different algorithms have been proposed in the literature to make the process of transferring knowledge from the teacher to the student model more efficient in complex settings. Let us look into these popular algorithms next.

Adversarial

Adversarial Learning underpins Generative Adversarial Networks (GANs), which have been highly successful in generative modeling. It has attracted attention in Knowledge Distillation as well, since GANs can be used to augment the existing training set or to help the teacher and student models learn the data distribution better.

For example, this paper proposed an adversarial Knowledge Distillation method for Event Detection in natural language. In the teacher training stage, the model learns knowledge representations from the ground-truth annotations. The student model is then set up to imitate the teacher’s behavior. A discriminator measures the similarity between the teacher and student models by examining their outputs, while the student model (acting as the generator in a GAN) tries to fool the discriminator by producing outputs that appear to come from the teacher. The architecture diagram for the method is shown below.

Source: Paper
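The adversarial setup described above can be sketched with a deliberately tiny discriminator. In this hypothetical example, a logistic-regression discriminator is trained to tell the teacher’s output distributions (label 1) from the student’s (label 0), and the student’s adversarial loss rewards outputs that the discriminator mistakes for the teacher’s. This is a generic GAN-style illustration, not the architecture of the cited paper.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def discriminate(w, b, p):
    """Discriminator's probability that the output distribution p came from the teacher."""
    return sigmoid(sum(wi * pi for wi, pi in zip(w, p)) + b)


def train_discriminator(teacher_outs, student_outs, epochs=500, lr=1.0):
    """Logistic regression trained with SGD on binary cross-entropy (teacher = 1, student = 0)."""
    w, b = [0.0] * len(teacher_outs[0]), 0.0
    samples = [(p, 1.0) for p in teacher_outs] + [(p, 0.0) for p in student_outs]
    for _ in range(epochs):
        for p, target in samples:
            err = discriminate(w, b, p) - target  # derivative of the BCE loss w.r.t. the logit
            w = [wi - lr * err * pi for wi, pi in zip(w, p)]
            b -= lr * err
    return w, b


def discriminator_loss(w, b, teacher_outs, student_outs):
    """The discriminator wants D(teacher) -> 1 and D(student) -> 0."""
    return (-sum(math.log(discriminate(w, b, p)) for p in teacher_outs)
            - sum(math.log(1.0 - discriminate(w, b, p)) for p in student_outs))


def student_adversarial_loss(w, b, student_outs):
    """The student (generator) wants its outputs to be classified as the teacher's."""
    return -sum(math.log(discriminate(w, b, p)) for p in student_outs)


# Hypothetical class-probability outputs for three inputs.
teacher_outs = [[0.8, 0.1, 0.1], [0.7, 0.2, 0.1], [0.85, 0.05, 0.1]]
student_outs = [[0.3, 0.5, 0.2], [0.2, 0.6, 0.2], [0.4, 0.4, 0.2]]

w, b = train_discriminator(teacher_outs, student_outs)
print(discriminator_loss(w, b, teacher_outs, student_outs))
print(student_adversarial_loss(w, b, student_outs))
```

In a full system, the student would be updated to lower `student_adversarial_loss` while the discriminator keeps retraining, alternating as in GAN training.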

Multi-Teacher

Different teacher architectures can provide different types of knowledge, which, when distilled into a student model, can produce better predictions than any individual model. An ensemble of several heavy architectures is generally used, and their soft label outputs are typically averaged (or combined with a weighted average) to supervise the student.

In this paper, researchers used a two-teacher framework in which each teacher distills knowledge through a different strategy. The first teacher transfers Feature-based Knowledge, guiding the student’s intermediate representations through adversarial learning instead of L2-loss-based guidance. The second teacher distills Response-based Knowledge, guiding the student through its output probability distribution. The schematic of the framework is depicted below.

Source: Paper

Cross-Modal

In practice, data is often available in multiple modalities. However, the data or labels for some modalities may be missing, corrupted, or otherwise unusable, so transferring knowledge between modalities is essential.

One example of multi-modal Knowledge Distillation is the framework proposed in this paper, where there are two heterogeneous teacher models—one for audio and another for image, and the student model is fed with video data. The aim here was to learn a compositional embedding that closes the cross-modal semantic gap and captures the task-relevant semantics, which facilitates pulling together representations across modalities by compositional Contrastive Learning. The learning framework is pictorially represented below.

Source: Paper
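Contrastive objectives are a common way to pull representations from different modalities together. The sketch below uses a plain InfoNCE-style loss, not the compositional variant from the cited paper: for each hypothetical video embedding from the student, the paired audio embedding from a teacher should be the most similar one in the batch. It assumes both embeddings have already been projected to the same size.

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def contrastive_loss(student_embs, teacher_embs, tau=0.1):
    """InfoNCE-style loss: student embedding i should match teacher embedding i best."""
    n = len(student_embs)
    total = 0.0
    for i in range(n):
        scores = [cosine(student_embs[i], teacher_embs[j]) / tau for j in range(n)]
        m = max(scores)
        log_normalizer = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_normalizer - scores[i]  # equals -log softmax(scores)[i]
    return total / n


# Hypothetical embeddings for 3 clips: the student sees video, a teacher sees the paired audio.
video = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
audio = [[0.9, 0.1], [0.1, 0.9], [0.6, 0.8]]
audio_mismatched = [audio[1], audio[2], audio[0]]

print(contrastive_loss(video, audio))
print(contrastive_loss(video, audio_mismatched))
```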

Graph-based

Sometimes, transferring only individual-instance knowledge from the teacher network to the student may not help the student model understand the relationships within the data. To address this, recent methods use graphs either as carriers of the teacher’s knowledge or to control how the teacher’s knowledge is passed to the student.

For example, this paper proposes an Inter-Region Affinity Knowledge Distillation (IntRA-KD) framework for road marking segmentation, which aims to transfer knowledge about scene structure more effectively from a large teacher network to a small student model. Because of its larger capacity, the teacher model is expected to learn more discriminative features and capture contextual information better than the student. The scene structure information is therefore represented as affinity graphs, which are matched between the teacher and student models during Knowledge Distillation. The framework of IntRA-KD is shown below.

Source: Paper

Attention-based

Attention is a key aspect of humans’ visual experience, and it closely relates to perception. Attention mechanisms have been used in neural network architectures for Computer Vision and NLP to allow the framework to “attend” to an object to examine it with greater detail. Such attention mechanisms have also been used in Knowledge Distillation for better student model learning.

For example, this paper matched the attention maps of intermediate layers (computed via Jacobian matrices) between the teacher and student models for knowledge distillation. The authors also proposed ways to match attention maps between deep networks with arbitrary architectures, whose attention maps may have different spatial dimensions. The proposed framework is illustrated below.

Source: Paper
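A widely used way to build attention maps for distillation is activation-based attention transfer: sum the squared activations over channels, flatten and L2-normalize the result, and penalize the distance between teacher and student maps. The sketch below implements that formulation, which may differ from the Jacobian-based variant described above, on hypothetical feature maps; teacher and student may have different channel counts but must share the spatial size.

```python
import math


def attention_map(feature_map):
    """feature_map[c][s]: activation of channel c at (flattened) spatial position s.

    Returns the L2-normalized map whose entry s is the sum over channels of squared activations.
    """
    positions = len(feature_map[0])
    att = [sum(channel[s] ** 2 for channel in feature_map) for s in range(positions)]
    norm = math.sqrt(sum(a * a for a in att))
    return [a / norm for a in att]


def attention_transfer_loss(teacher_fm, student_fm):
    """L2 distance between the normalized teacher and student attention maps."""
    a_t, a_s = attention_map(teacher_fm), attention_map(student_fm)
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a_t, a_s)))


# Hypothetical 2x2 feature maps flattened to 4 positions; the teacher has 4 channels, the student 2.
teacher_fm = [[1.0, 0.0, 0.0, 2.0], [0.5, 0.0, 0.0, 1.0], [1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0]]
student_similar = [[2.0, 0.0, 0.0, 3.0], [0.0, 0.0, 0.0, 1.0]]    # attends to the same positions
student_different = [[0.0, 3.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]  # attends elsewhere

print(attention_transfer_loss(teacher_fm, student_similar))
print(attention_transfer_loss(teacher_fm, student_different))
```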

Data-Free

The student model in a Knowledge Distillation framework performs best when it has access to the data used to pre-train the teacher network. However, this data may not always be available, either because of its sheer volume (a complex teacher network needs more data to train) or because of privacy and confidentiality concerns. This is especially true in biomedical applications, where the patient data used to train the teacher model cannot be released for the student model to train on.

Thus, Data-Free Knowledge Distillation techniques emerged, which aim to generate synthetic data that mimics the data distribution of the training samples of the teacher model. One such method is the DeepInversion for Object Detection (DIODE) framework proposed in this paper, which deviates from traditional GAN-based methods for image synthesis and uses DeepInversion. DeepInversion optimizes a batch of images, starting from noise, by matching the statistics of deep feature distributions to those stored in the network's batch-normalization (BN) layers. The illustration of the DIODE framework is shown below.
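The batch-normalization statistics term at the heart of DeepInversion can be sketched as follows: for every BN layer, compare the per-channel mean and variance of the current batch’s activations with the running statistics stored in the teacher. In DeepInversion, synthetic images are optimized (alongside other loss terms) to make this quantity small; here the activations are given directly, and all values are hypothetical.

```python
import math
import random


def channel_statistics(activations):
    """activations[c] holds every value of channel c across the batch (and spatial positions)."""
    means, variances = [], []
    for values in activations:
        mu = sum(values) / len(values)
        means.append(mu)
        variances.append(sum((v - mu) ** 2 for v in values) / len(values))
    return means, variances


def l2(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def bn_statistics_loss(layers):
    """Sum over BN layers of ||mu(x) - running_mean||_2 + ||var(x) - running_var||_2."""
    total = 0.0
    for activations, running_mean, running_var in layers:
        mu, var = channel_statistics(activations)
        total += l2(mu, running_mean) + l2(var, running_var)
    return total


# Hypothetical statistics stored in one BN layer with 2 channels.
running_mean, running_var = [2.0, 2.0], [1.0, 4.0]

matching = [[1.0, 3.0, 1.0, 3.0], [0.0, 0.0, 4.0, 4.0]]  # channel stats: (2, 1) and (2, 4)
random.seed(0)
noise = [[random.gauss(0.0, 1.0) for _ in range(4)] for _ in range(2)]

print(bn_statistics_loss([(matching, running_mean, running_var)]))  # 0.0
print(bn_statistics_loss([(noise, running_mean, running_var)]))
```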

Quantized

Most edge devices today require fixed-point inference for computational and power efficiency. Hence, the weights and activations of deep networks deployed on such devices often need to be quantized to a lower bit-width. Quantized Knowledge Distillation aims to train a low-precision student model (e.g., 2-bit or 8-bit) from a high-precision teacher network (e.g., 32-bit floating point).

For example, this paper quantized the weights of the student to a limited set of integer levels and used fewer weights per layer. The authors attempt this in two ways. First (as shown below), they aim to leverage distillation loss during the training process by incorporating it into the training of the student network, whose weights are constrained to a limited set of levels. Alternatively, they attempt to converge to the optimal location of quantization points through stochastic gradient descent optimization.

Source: Paper
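The basic operation behind a low-bit student is mapping each full-precision weight to one of a small set of levels. The sketch below is a plain min-max uniform quantizer, an assumption rather than the exact function used in the cited paper: with `bits` bits there are at most 2**bits levels. In quantized distillation, the student’s forward pass and the distillation loss would use such quantized weights, while gradient updates are typically applied to a full-precision copy.

```python
def quantize_uniform(weights, bits):
    """Map each weight to the nearest of 2**bits evenly spaced levels between min and max."""
    lo, hi = min(weights), max(weights)
    if hi == lo:
        return list(weights)
    step = (hi - lo) / (2 ** bits - 1)
    return [lo + round((w - lo) / step) * step for w in weights]


weights = [-0.82, -0.31, 0.05, 0.12, 0.47, 0.90]  # hypothetical full-precision student weights
q2 = quantize_uniform(weights, bits=2)  # at most 4 distinct values
q8 = quantize_uniform(weights, bits=8)  # at most 256 distinct values
print(q2)
```

The rounding error per weight is at most half a step, so it shrinks as the bit-width grows.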

Lifelong

Humans can recognize images of novel categories after seeing only a few examples, because they can draw not only on explicit visual information about the new objects but also on discriminative visual information from prior knowledge. Lifelong Learning aims to give networks a similar ability.

Source: Paper

One example of Lifelong Distillation is the image recognition framework proposed in this paper. The authors proposed a few-shot learning setup for image recognition that jointly incorporates visual feature learning, knowledge inference, and classifier training into one framework. They used knowledge graphs to train the knowledge-based classifiers augmented from their vision-based counterparts. The framework proposed is shown below.

Source: Paper

Practical applications of Knowledge Distillation

As an effective technique for the compression and acceleration of deep neural networks, Knowledge Distillation has been widely used in different fields of Artificial Intelligence. Let us discuss some of these practical applications next.

Computer Vision

Knowledge Distillation has been used extensively in various Computer Vision tasks over the years—video and image segmentation and classification, human action recognition, pose estimation, etc.

For example, this paper proposes the CogniNet model, which uses Knowledge Distillation from pre-trained vision models to help train a BiLSTM (Bidirectional Long Short-Term Memory) model to classify brain EEG signals. In other words, a state-of-the-art deep convolutional vision network teaches a deep recurrent network to classify brain signals correctly. The student model does not depend on visual cues during inference. The CogniNet framework is shown below.

Source: Paper

Visual Question Answering

Visual Question Answering or VQA deals with the problem where a model needs to find an answer to a question about an input image. This is a highly challenging problem because VQA models deal with various recognition tasks simultaneously within a unified framework, which requires understanding the local and global context of an image and a natural text-based question. A VQA model thus should have diverse reasoning capabilities to capture appropriate information from input images and questions.

Examples of VQA. Source: Paper

Multiple Choice Learning (MCL) is a framework in which each example is assigned to the subset of models that perform best on it, so each model is expected to specialize in certain types of examples. In VQA, Mun et al. proposed MCL-KD (Multiple Choice Learning with Knowledge Distillation), a framework that trains models to specialize in subsets of tasks. The authors argue that specialized models have the potential to outperform models generalized on all tasks, since models trained with MCL achieve higher oracle accuracy, that is, a higher fraction of examples for which at least one of the models predicts correctly.

Source: Paper

The authors use Knowledge Distillation so that the specialized models are trained to predict the ground-truth answers, while the non-specialized ones are trained to preserve the representations of the corresponding base models. The overall framework of their approach is shown above.
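Oracle accuracy, mentioned above, is easy to state in code: an example counts as correct if at least one model answers it correctly. The sketch below uses hypothetical predictions from three specialized models.

```python
def accuracy(predictions, labels):
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


def oracle_accuracy(all_predictions, labels):
    """Fraction of examples answered correctly by at least one model."""
    hits = sum(
        any(model_preds[i] == y for model_preds in all_predictions)
        for i, y in enumerate(labels)
    )
    return hits / len(labels)


labels = ["yes", "no", "2", "red", "yes", "3"]
models = [
    ["yes", "no", "1", "blue", "yes", "2"],  # model A: right on examples 0, 1, 4
    ["no", "yes", "2", "red", "no", "2"],    # model B: right on examples 2, 3
    ["no", "no", "1", "blue", "no", "3"],    # model C: right on examples 1, 5
]

print([round(accuracy(m, labels), 3) for m in models])  # [0.5, 0.333, 0.333]
print(oracle_accuracy(models, labels))                  # 1.0
```

Each model alone is right on at most half of the examples, yet the oracle accuracy is 1.0 because every example is answered correctly by at least one model.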

NLP

Natural Language Processing (NLP) deals with AI systems that can comprehend natural human language (text, in most cases). Knowledge Distillation has many applications in NLP, since large pre-trained language models such as BERT have complex, cumbersome structures and consume a great deal of time and resources.

For example, this paper proposed a Knowledge Distillation-based method for extending existing sentence embedding models to new languages, thus creating multilingual versions from previously monolingual models. The authors use an English Sentence-BERT as the teacher model and XLM-R as the student model. The training procedure adopted in this paper is shown below.

Source: Paper
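The teacher-student objective for extending sentence embeddings to a new language can be sketched as follows, as a minimal illustration under our own assumptions: the student must reproduce the teacher’s embedding of an English sentence and map that sentence’s translation onto the same vector, measured with mean squared error. The lookup tables stand in for real encoders and are entirely hypothetical.

```python
def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def multilingual_distillation_loss(pairs, teacher_embed, student_embed):
    """pairs: (English sentence, translation) tuples; the teacher only ever embeds English."""
    total = 0.0
    for source, translation in pairs:
        target = teacher_embed(source)
        # The student must reproduce the teacher's vector for the English sentence
        # and map the translation onto that same vector.
        total += mse(student_embed(source), target) + mse(student_embed(translation), target)
    return total / len(pairs)


# Toy lookup tables standing in for real encoders (hypothetical 3-D embeddings).
TEACHER = {"hello world": [0.2, 0.9, -0.1], "good morning": [0.7, -0.3, 0.5]}
STUDENT = {
    "hello world": [0.2, 0.9, -0.1], "hallo welt": [0.1, 0.8, 0.0],
    "good morning": [0.7, -0.3, 0.5], "guten morgen": [0.7, -0.3, 0.5],
}
pairs = [("hello world", "hallo welt"), ("good morning", "guten morgen")]
loss = multilingual_distillation_loss(pairs, TEACHER.__getitem__, STUDENT.__getitem__)
print(loss)  # ≈ 0.005
```

Only the German “hallo welt” embedding is off the teacher’s target, so it alone contributes to the loss.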

Speech Recognition

Spoken language identification is an essential task in multilingual applications like automated speech translation systems. This paper proposes one such framework for language identification in short-duration speech.

Unlike in most Knowledge Distillation methods, the authors here refrain from using a fixed pre-trained teacher model and instead develop an interactive teacher-student framework to improve teacher-student learning by adjusting the teacher model with reference to the performance of the student model. The learning mechanism adopted by the authors is shown below.

Source: Paper

Recommendation Systems

External knowledge, such as user reviews or product images, has been widely incorporated into modern recommender systems. Such knowledge can reveal comprehensive user/item properties and enhance the recommendation performance as well as make the recommendations more explainable.

For example, this paper developed a generalized distillation framework to obtain a model that is efficient at test time in a recommender system that uses external user reviews. The authors employed a complex CNN model as the teacher, which is used only during the training phase.

Source: Paper

Neural Architecture Search

Neural Architecture Search (NAS) is a field of Automated Machine Learning (AutoML) concerned with automatically identifying a Deep Learning model architecture through an optimization algorithm.

This paper, for example, proposed a Knowledge Distillation framework named “DNA” that distills the neural architecture from an existing architecture. Since different blocks of an existing architecture capture different patterns in an image, the authors use a block-wise representation of existing models to supervise the architecture search. The “DNA” framework is shown below.

Source: Paper

Knowledge Distillation: Benefits and Limitations

Let’s now have a quick look at the pros and cons of knowledge distillation:


Key Takeaways

Modern deep networks are as computationally heavy as they are powerful. Deploying such models on edge or mobile devices is often impractical without model compression. Knowledge Distillation has become a key technique in this field, allowing efficient, lightweight models to retain much of the performance of heavy models.

Several different schemes exist for Knowledge Distillation, each with its pros and cons relative to the application scenario. It has been successfully implemented in numerous domains like Computer Vision, NLP, NAS, etc., but it is not perfect yet. Researchers are still trying to address the challenges associated with Knowledge Distillation, like the compatibility issues of teacher-student models and the ability to integrate Knowledge Distillation with other learning schemes.


Rohit Kundu


Rohit Kundu is a Ph.D. student in the Electrical and Computer Engineering department of the University of California, Riverside. He is a researcher in the Vision-Language domain of AI and has published several papers in top-tier conferences and notable peer-reviewed journals.
