Multi-Task Learning in ML: Optimization & Use Cases [Overview]

Dec 2, 2022

Learn the basics of multi-task learning in deep neural networks. See its practical applications, when to use it, & how to optimize the multi-task learning process.

Rohit Kundu

Deep Learning has been the de-facto choice for solving complex problems in Artificial Intelligence, like Computer Vision and Natural Language Processing, for several years now. However, most of these learning algorithms are trained for solving one single task, popularly called Single-Task Learning.

Single-Task Learning methods aim at optimizing a single metric. Sometimes even an ensemble of models is used, all of which are dedicated to the same task. However, while focusing on one task, we may lose other important information not included in the problem, thus limiting or saturating the model performance.

It may be beneficial to jointly optimize a single model for multiple tasks. By sharing representations between related tasks, a model can learn better decision boundaries on the original task. Such an approach is called “Multi-Task Learning.”

Here’s what we’ll cover:

  • What is a Multi-Task Learning model?

  • When should Multi-Task Learning be used?

  • Optimization Methods for Multi-Task Learning

  • Practical Applications of Multi-Task Learning

What is a Multi-Task Learning model?

As the name suggests, Multi-Task Learning refers to a single shared machine learning model that can perform multiple different (albeit related) tasks. Multi-Task Learning offers advantages like improved data efficiency, faster model convergence, and reduced model overfitting due to shared representations.

Multi-Task Learning resembles the mechanism of human learning more closely than Single-Task Learning because we humans often learn transferable skills. For example, learning to ride a bicycle makes it easy for someone to learn to ride a motorbike later on, which builds upon similar concepts of body balance. Moreover, learning to ride a motorcycle with gears helps to learn to drive a car with manual transmission faster. This is referred to as the inductive transfer of knowledge.

This mechanism of knowledge transfer is what allows humans to learn new concepts with only a few examples or no examples at all (which in Machine Learning is called “Few-Shot Learning” and “Zero-Shot Learning,” respectively).

In contrast, neural networks trained to perform a single task need to learn all the underlying concepts from scratch, even though some of them could have been obtained from other tasks as auxiliary information. This increases the overall computational cost of the single-task models.

For example, a model trained for image classification learns to classify samples by localizing specific objects in the images. Such a model, when used in conjunction with solving an object detection problem, already has the underlying feature localization capability, leading to faster model convergence. Thus, the same model is used to generalize over different tasks.

When should Multi-Task Learning be used?

In multi-task learning, multiple tasks are solved at the same time, typically with a single neural network. In addition to reduced inference time, solving a set of tasks jointly rather than independently can, in theory, have other benefits, such as improved prediction accuracy, increased data efficiency, and reduced training time.

Unfortunately, the quality of predictions is often observed to suffer when a network is tasked with making multiple predictions due to a phenomenon called “negative transfer.” In fact, multi-task performance can suffer so much that smaller independent networks are often superior. This may be because the tasks must be learned at different rates or because one task may dominate the learning leading to poor performance on other tasks. Furthermore, task gradients may interfere, and multiple summed losses may make the optimization landscape more challenging.

Nevertheless, when task objectives do not interfere much with each other, performance on both tasks can be maintained or even improved when jointly trained. Intuitively, this loss or gain of quality seems to depend on the relationship between the jointly trained tasks.

At the very least, the different tasks to be solved by a network need to have some inherent correlation. For example, a 2D keypoint detection task, where a model localizes important parts of an image, is correlated to a 2D edge detection task. Similarly, a depth estimation task is closely correlated to a Surface Normal Prediction task, which predicts the surface orientation of the objects in a scene.

Ideally, a multi-task learning model will apply the information it learns during training on one task to decrease the loss on other tasks included in training the network. An exhaustive search for subsets is a possible solution to finding tasks that should be grouped together. However, this is extremely inefficient computationally.

To solve this challenge, this paper by the Google Brain team drew inspiration from meta-learning (like Zero-Shot Learning), i.e., the concept of “learning to learn” to build “Task Affinity Groupings” or TAG.

task affinity grouping overview

Source: Paper

TAG updates its model parameters with respect to only one single task and evaluates how this update affects the performance of the other tasks, and then undoes this update. This process is then repeated for every other task to gather information on how each task in the network would interact with any other task. Training then continues as normal by updating the model’s shared parameters with respect to every task in the network.
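The lookahead-and-undo step described above can be sketched on a toy problem. Everything here is an assumption made for illustration: one shared scalar parameter, quadratic per-task losses, the hypothetical task names, and the affinity score defined as one minus the ratio of a task's loss after and before the lookahead step. It is not the authors' implementation.

```python
# Toy setup (assumed for illustration): one shared scalar parameter and
# quadratic per-task losses L_k(theta) = (theta - target_k)^2.
TARGETS = {"task_a": 1.0, "task_b": 1.2, "task_c": -1.0}

def loss(task, theta):
    return (theta - TARGETS[task]) ** 2

def grad(task, theta):
    return 2.0 * (theta - TARGETS[task])

def affinity_onto_others(theta, source, lr=0.1):
    """Take a lookahead step on `source` only and score its effect on other tasks."""
    lookahead = theta - lr * grad(source, theta)  # temporary update
    scores = {}
    for other in TARGETS:
        if other != source:
            # Positive when the step on `source` lowers the loss of `other`.
            scores[other] = 1.0 - loss(other, lookahead) / loss(other, theta)
    return scores  # `theta` is never modified, so the update is effectively undone

theta = 0.0
affinities = {source: affinity_onto_others(theta, source) for source in TARGETS}
```

In this toy case, a step on task_a helps task_b (their targets lie in the same direction) and hurts task_c, so the first pair would be the natural candidates for grouping.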

Using this strategy, the authors found that some tasks exhibit beneficial relationships when grouped together, while other groupings hurt overall model performance. Thus, multi-task learning is effective when the jointly optimized tasks have high affinity; tasks cannot simply be assigned to groups at random.

Optimization Methods for Multi-Task Learning

As we’ve already discussed, not all tasks can be claimed to be correlated and thus used in a Multi-Task Learning framework. The imbalance of datasets, the dissimilarity between tasks, negative transfer of knowledge, all pose challenges to Multi-Task Learning. Thus, the optimization of tasks is as important as selecting proper architectures for obtaining the best possible performance. Different strategies are used in the literature for optimization, which we will discuss next.

Loss Construction

This is one of the most intuitive ways of performing multi-task optimization—by balancing the individual loss functions defined for the separate tasks, using different weighting schemes. The model then optimizes the aggregated loss function as a way to learn multiple tasks at once.

Different loss weighting mechanisms have been used in the literature to aid the Multi-Task problem. For example, this paper assigned weights to the individual loss functions to be inversely proportional to the training set sizes of the respective tasks so as to not let a task having more data dominate the optimization.
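As a rough sketch of this weighting scheme, the snippet below builds weights proportional to the inverse of each task's training set size and uses them to aggregate per-task losses. The task names, dataset sizes, loss values, and the normalization of the weights to sum to 1 are assumptions made for illustration.

```python
def inverse_size_weights(train_sizes):
    """Weights proportional to 1 / dataset size, normalized to sum to 1."""
    inverse = {task: 1.0 / n for task, n in train_sizes.items()}
    total = sum(inverse.values())
    return {task: value / total for task, value in inverse.items()}

def combined_loss(task_losses, weights):
    """Weighted sum of the individual task losses."""
    return sum(weights[task] * task_losses[task] for task in task_losses)

sizes = {"segmentation": 10_000, "depth": 1_000}  # hypothetical dataset sizes
weights = inverse_size_weights(sizes)
total_loss = combined_loss({"segmentation": 0.8, "depth": 0.5}, weights)
```

Here the task with ten times less data receives ten times the weight, so the larger dataset does not dominate the aggregated loss.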

Hard Parameter Sharing

In Hard Parameter Sharing, the hidden layers of the neural networks are shared while keeping some task-specific output layers. Sharing most of the layers for the related tasks reduces the chances of overfitting.

Illustration of Hard Parameter Sharing

According to this paper, the more tasks a shared model is learning simultaneously, the more it has to find a representation that captures all of the tasks, and the smaller the chance of overfitting on our original task.
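A minimal sketch of hard parameter sharing, written in plain Python so that it runs without a deep learning framework. The layer sizes, random weights, and task names are hypothetical; a real model would use a trained network.

```python
import random

def linear(inputs, weights, biases):
    """Dense layer: one output per row of `weights`."""
    return [sum(w * x for w, x in zip(row, inputs)) + b
            for row, b in zip(weights, biases)]

def relu(values):
    return [max(0.0, v) for v in values]

def init_layer(rng, n_out, n_in):
    weights = [[rng.uniform(-1, 1) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out

rng = random.Random(0)
shared = init_layer(rng, n_out=4, n_in=3)        # hidden layer shared by all tasks
heads = {
    "task_a": init_layer(rng, n_out=2, n_in=4),  # task-specific output layer
    "task_b": init_layer(rng, n_out=1, n_in=4),  # task-specific output layer
}

def forward(x):
    hidden = relu(linear(x, *shared))            # computed once, reused by every head
    return {task: linear(hidden, *params) for task, params in heads.items()}

outputs = forward([0.5, -1.0, 2.0])
```

Because every task's loss backpropagates into the same `shared` layer, that layer has to find a representation that serves all heads, which is the source of the regularizing effect described above.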

HyperFace is a Multi-Task framework for Face Detection, Landmark Localization, Pose Estimation, and Gender Recognition. It fuses shallow features extracted from the data, passes them through a single CNN model with hard parameter sharing, and uses separate fully connected layers for the different tasks. The architectural diagram of the HyperFace framework is shown below.

architecture of hyperface

Source: Paper

Soft Parameter Sharing

Hard parameter sharing performs well only if tasks are closely related. Therefore, new approaches have focused on learning the features that need to be shared between tasks. Soft parameter sharing refers to adding a regularization term on the distance between the parameters of the individual models to the overall training objective, which encourages similar model parameters across the different tasks. It is commonly used in Multi-Task Learning as such regularization techniques are easy to implement.

Illustration of Soft Parameter Sharing

For example, the authors in this paper regularized the model parameters by using the L2 norm of the distance between them for an NLP task. This paper replaced the L2 norm with the tensor trace norm of the tensor formed by stacking corresponding parameter vectors from different tasks.
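The L2 variant of soft parameter sharing can be sketched as a penalty added to the sum of the task losses. The parameter values, loss values, and the `strength` coefficient below are hypothetical, and the trace-norm variant is not shown.

```python
def l2_distance_penalty(params_a, params_b):
    """Squared L2 distance between two same-shaped parameter vectors."""
    return sum((a - b) ** 2 for a, b in zip(params_a, params_b))

def soft_sharing_objective(loss_a, loss_b, params_a, params_b, strength=0.1):
    """Task losses plus a penalty that pulls the two models' parameters together."""
    return loss_a + loss_b + strength * l2_distance_penalty(params_a, params_b)

theta_a = [0.5, -1.0, 2.0]  # parameters of the task A model (hypothetical)
theta_b = [0.4, -0.8, 2.5]  # parameters of the task B model (hypothetical)
objective = soft_sharing_objective(0.30, 0.45, theta_a, theta_b)
```

Each task keeps its own parameters; a larger `strength` pulls the two models closer together, while a value of zero recovers two independent single-task models.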

Data Sampling

Machine Learning datasets often suffer from imbalanced data distributions. Multi-Task Learning further complicates this issue since training datasets of multiple tasks with potentially different sizes and data distributions are involved. The multi-task model has a greater probability of sampling data points from tasks with a larger available training dataset, leading to potential overfitting.

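To handle this data imbalance, various data sampling techniques have been proposed to properly construct training datasets for the Multi-Task optimization problem. For example, researchers use a “temperature” T for data sampling. In a common formulation, task i is sampled with probability p_i ∝ |D_i|^(1/T), where |D_i| is the size of its training dataset: T = 1 samples in proportion to dataset size, and larger values of T move the distribution toward uniform sampling.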

An example of this is the method used in this paper, which tackles the problem of multilingual neural machine translation using a dynamic temperature-based sampling strategy for multilingual data. That is, with every epoch, the temperature coefficient is updated based on the model performance on the different tasks.
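A common form of temperature-based sampling draws task i with probability proportional to its dataset size raised to the power 1/T. The sketch below uses a fixed temperature; the dynamic schedule from the paper is not reproduced, and the language pairs and dataset sizes are hypothetical.

```python
import random

def sampling_probabilities(dataset_sizes, temperature):
    """p_i proportional to |D_i| ** (1 / T); T = 1 is size-proportional sampling."""
    scaled = {task: n ** (1.0 / temperature) for task, n in dataset_sizes.items()}
    total = sum(scaled.values())
    return {task: s / total for task, s in scaled.items()}

sizes = {"en-fr": 1_000_000, "en-sw": 10_000}  # hypothetical language pairs
proportional = sampling_probabilities(sizes, temperature=1.0)
flattened = sampling_probabilities(sizes, temperature=5.0)

# Draw the tasks for the next eight training batches from the flattened distribution.
rng = random.Random(0)
batch_tasks = rng.choices(list(flattened), weights=list(flattened.values()), k=8)
```

Raising the temperature increases how often the low-resource task is sampled, which trades off some overfitting risk on the small dataset against under-training on it.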

Task Scheduling

Most Multi-Task Learning models make a decision on which task(s) to train on in an epoch in a very simple way, either training on all tasks at each step or randomly sampling a subset of tasks to train on. However, intelligently optimized task scheduling can significantly improve the overall model performance on all tasks.

For example, this paper proposes a multi-task multi-lingual model where tasks are scheduled according to the similarity between each task and the primary task. The authors consider both task similarity and the number of training samples available for the task to compute the choice for tasks to be trained in each step. The architecture developed by the authors is shown below.

multi-task multilingual architecture

Source: Paper

Another example is the approach adopted in this paper, where task scheduling has been employed in a multi-task Active Learning framework. The idea here is to assign task scheduling probabilities based on relative performance to a target level: the further the model is from the target performance on a given task, the more likely it is that the task will be scheduled. This is akin to the loss construction (specifically weighting individual losses) methods we saw before, which increase the loss weight of a task that exhibits slow learning. This task scheduling procedure is visually represented below.
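The idea of scheduling tasks by their distance from a target performance can be sketched as follows. The exact formula used in the paper is not reproduced here; the relative-gap score, the uniform fallback, and the task names and scores are assumptions made for illustration.

```python
def scheduling_probabilities(current_scores, target_scores):
    """Probability of scheduling each task, proportional to its relative gap to target."""
    gaps = {
        task: max(0.0, (target_scores[task] - current_scores[task]) / target_scores[task])
        for task in target_scores
    }
    total = sum(gaps.values())
    if total == 0.0:  # every task has reached its target: fall back to uniform
        return {task: 1.0 / len(gaps) for task in gaps}
    return {task: gap / total for task, gap in gaps.items()}

probs = scheduling_probabilities(
    current_scores={"ner": 0.60, "pos": 0.90},  # hypothetical validation scores
    target_scores={"ner": 0.80, "pos": 0.95},   # hypothetical target scores
)
```

The task that is furthest from its target ("ner" in this example) receives most of the probability mass, so it is trained on more often until the gap closes.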

visualization of active-sampling based multi-task learning framework

Source: Paper

Gradient Modulation

Most Multi-Task Learning approaches assume that the individual tasks used for the joint optimization are closely related. However, each task might not be closely related to all available tasks. In such cases, sharing information with an unrelated task might even hurt performance, a phenomenon known as “negative transfer.”

From an optimization perspective, negative transfer manifests as the presence of conflicting task gradients. When two tasks have gradient vectors that point in opposing directions, following the gradient for one task will decrease the performance of the other task. Following the average of the two gradients means that neither task sees the same improvement as in a single-task training setting. Thus, modulation of task gradients is a potential solution to this problem.
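The conflict can be made concrete with two hypothetical gradient vectors. The sketch checks whether they conflict (negative cosine similarity) and compares the first-order progress on task A when following the averaged gradient with the progress when following task A's own gradient. The vectors are made up for illustration.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def cosine_similarity(u, v):
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))

grad_a = [1.0, 0.2]    # hypothetical gradient of task A on the shared parameters
grad_b = [-0.8, 0.6]   # hypothetical gradient of task B on the shared parameters

# Gradients conflict when they point in opposing directions.
conflict = cosine_similarity(grad_a, grad_b) < 0.0

# Joint training follows the average of the two gradients.
average = [(a + b) / 2.0 for a, b in zip(grad_a, grad_b)]

# First-order loss decrease for task A is proportional to the dot product
# between the update direction and task A's own gradient.
progress_a_joint = dot(average, grad_a)
progress_a_alone = dot(grad_a, grad_a)
```

With these values the averaged step still helps task A, but far less than a step along its own gradient, which is the slowdown that gradient modulation methods try to reduce.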

If a multi-task model is training on a collection of related tasks, then ideally, the gradients from these tasks should point in similar directions. One common way to modulate gradients is through adversarial training. For example, the Gradient Adversarial Training (GREAT) method enforces this condition explicitly by including an adversarial loss term in the multi-task model training that encourages gradients from different sources to have statistically indistinguishable distributions.

gradient adversarial training of neural networks

Source: Paper

Knowledge Distillation

Knowledge Distillation is a Machine Learning paradigm where knowledge is transferred from a larger computationally expensive model (called the “teacher” model) to a smaller model (called the “student” model) while retaining performance.

In Multi-Task Learning, the most common use of Knowledge Distillation is to distill the knowledge from several individual single-task “teacher” networks to a single multi-task “student” network. Interestingly, the performance of the student network has been shown to surpass that of the teacher networks in some domains, making knowledge distillation a desirable method not just for saving memory but also for increasing performance.

For example, this paper extended Knowledge Distillation for Multi-Task Learning to a Natural Language Understanding problem. In the training process, a few tasks are selected, and for each task, the authors train an ensemble of MT-DNN models (teachers) that outperform the best single model. Then, a single MT-DNN (student) model is trained via multi-task learning with the help of the teachers by using both the soft targets (classification probabilities predicted by the teacher models) and correct targets across different tasks. The training procedure is illustrated below.
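The combination of soft and correct targets can be sketched as a weighted sum of two cross-entropy terms. Averaging the teachers' predictions to form the soft targets, the mixing weight `alpha`, and all numbers below are assumptions made for illustration rather than the paper's exact training recipe.

```python
import math

def softmax(logits):
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target_probs, predicted_probs):
    return -sum(t * math.log(p) for t, p in zip(target_probs, predicted_probs) if t > 0)

def distillation_loss(student_logits, teacher_probs, true_label, alpha=0.5):
    """alpha weights the soft-target term; (1 - alpha) weights the correct-target term."""
    student_probs = softmax(student_logits)
    hard = [1.0 if i == true_label else 0.0 for i in range(len(student_probs))]
    soft_term = cross_entropy(teacher_probs, student_probs)
    hard_term = cross_entropy(hard, student_probs)
    return alpha * soft_term + (1.0 - alpha) * hard_term

# Soft targets: average the class probabilities predicted by an ensemble of teachers.
teacher_outputs = [[0.7, 0.2, 0.1], [0.6, 0.3, 0.1]]
teacher_probs = [sum(column) / len(teacher_outputs) for column in zip(*teacher_outputs)]
loss = distillation_loss([2.0, 0.5, -1.0], teacher_probs, true_label=0)
```

In a multi-task student, a loss of this form would be computed per task on that task's output head, with the shared layers receiving gradients from all of them.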

architecture of a mtdnn model

Source: Paper

Practical Applications of Multi-Task Learning

Multi-Task Learning frameworks are used by researchers in many domains of Artificial Intelligence for developing resource-optimized models. Reliable multi-task models can be used in several application areas that have storage constraints, such as biomedical facilities and space probes. Let us look at the recent applications of such models in different realms of AI.

Computer Vision

Computer Vision is the branch of Artificial Intelligence that deals with problems like image classification, object detection, video retrieval, etc.—problems that are being tackled every day by our own smartphones, for example, when using fingerprints or face detection to unlock the screen or getting recommendations on YouTube. Multi-Task Learning benefits such tasks considerably both in terms of performance and resource efficiency.

Most single-task Computer Vision models are extremely computationally expensive, being very deep networks. Tackling multiple tasks with a multi-task network saves storage space and makes it easier to deploy in more real-world problems. Further, it helps alleviate the problem of requiring a large quantity of labeled data for model training.

The Cross-Stitch Network proposed in this paper is an example of a Multi-Task computer vision model. The authors analyze two pairs of tasks that they believe are closely related: (1) Semantic Segmentation and Surface Normal Prediction, since segmentation boundaries also correspond to surface normal boundaries; and (2) Object Detection and Attribute Prediction, since a region detected as an “object” can also act as a positive sample for particular attributes (for example, the object “dog” can be a positive sample for the attribute “four legs”).

cross-stitch units illustration

Source: Paper

The Cross-Stitch Network automatically learns an optimal combination of shared and task-specific representations. It models shared representations using linear combinations and learns the optimal linear combinations for a given set of tasks. The authors integrate these cross-stitch units into a Convolutional Network and provide an end-to-end learning framework.

integrating cross-stitch units into a Convolutional Network
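A cross-stitch unit for two tasks can be sketched as a 2x2 matrix of mixing weights applied to the two tasks' activations at the same position. In the network these weights are learned by backpropagation; in this sketch they are fixed, hypothetical values.

```python
def cross_stitch(features_a, features_b, alpha):
    """Mix two tasks' activations with a 2x2 matrix `alpha` of mixing weights."""
    mixed_a = [alpha[0][0] * a + alpha[0][1] * b for a, b in zip(features_a, features_b)]
    mixed_b = [alpha[1][0] * a + alpha[1][1] * b for a, b in zip(features_a, features_b)]
    return mixed_a, mixed_b

alpha = [[0.9, 0.1],   # task A keeps 90% of its own features, takes 10% from B
         [0.2, 0.8]]   # task B takes 20% from A, keeps 80% of its own
out_a, out_b = cross_stitch([1.0, 2.0], [3.0, 4.0], alpha)
```

Setting `alpha` to the identity matrix yields two fully task-specific networks, while equal entries yield fully shared representations; learned values fall somewhere in between.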

Source: Paper

Another architecture that differs from the traditional concepts of hard or soft parameter sharing is the AdaShare model, which develops a generic Multi-Task system and has been evaluated on the Semantic Segmentation+Surface Normal Prediction joint optimization task. The authors argue that an optimal Multi-Task Learning algorithm should not only achieve high accuracy on all tasks but also restrict the number of new network parameters as much as possible as the number of tasks grows. This is extremely important for many resource-limited applications, such as autonomous vehicles and mobile platforms, that would benefit from multi-task learning.

AdaShare learns the feature-sharing pattern to achieve the best recognition accuracy while restricting the memory footprint as much as possible. The authors seek to learn the sharing pattern through a task-specific policy that selectively chooses which layers to execute for a given task in the multi-task network. In other words, the aim of AdaShare is to obtain a single network for multi-task learning that supports separate execution paths for different tasks. The concept of AdaShare is shown below.
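The notion of separate execution paths through one network can be sketched with a per-task policy that marks each block as executed or skipped. AdaShare learns this policy during training; here it is hard-coded, and the blocks are simple scaling functions chosen only to keep the example short.

```python
def make_block(scale):
    """Stand-in for a network block: multiplies every activation by `scale`."""
    return lambda values: [scale * v for v in values]

blocks = [make_block(2.0), make_block(0.5), make_block(3.0)]

# Hypothetical per-task policies: True = execute the block, False = skip it.
policy = {
    "segmentation": [True, True, False],
    "surface_normals": [True, False, True],
}

def run(task, inputs):
    hidden = inputs
    for block, execute in zip(blocks, policy[task]):
        if execute:
            hidden = block(hidden)  # skipped blocks pass activations through unchanged
    return hidden

seg_out = run("segmentation", [1.0])
normals_out = run("surface_normals", [1.0])
```

Both tasks share the first block and differ in the rest, so the total parameter count stays that of a single network regardless of how many tasks are added.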

adashare visualization

Source: Paper

Natural Language Processing

Natural Language Processing (NLP) is a branch of Artificial Intelligence that deals with natural human language inputs such as text (in any language) and speech. It encompasses several applications like sentence translation, image or video captioning, emotion detection, etc. Multi-Task Learning is extensively used in NLP problems to boost the performance of a primary task with auxiliary tasks.

For example, this paper addresses the problem of natural text classification using an adversarial Multi-Task Learning framework. The authors introduce the adversarial learning framework to distill learned features into task-specific and task-agnostic subspaces. Their architecture consists of a single shared Long Short-Term Memory (LSTM) layer and one task-specific LSTM layer per task. Once the input sentence from a task is passed through the shared LSTM layer and the task-specific LSTM layer, the two outputs are concatenated and used as the final features to perform model inference.

However, the features produced by the shared LSTM layer are also fed into the task discriminator. The task discriminator is trained to predict which task the original input sentence came from. The shared LSTM layer is then trained to jointly minimize the task loss along with the discriminator loss so that the features produced by the shared LSTM do not contain any task-specific information. The architecture proposed in the paper is shown below.

adversarial shared-private model

Source: Paper

Recommendation Systems

Personalized recommendation has become a major technique for helping users handle huge amounts of online content. To improve user experience, it is essential that the recommendation model accurately predicts users’ personal preferences for items.

An example of a Multi-Task recommender system is the Co-Attentive Multi-task Learning (CAML) model, where the authors enhance both the accuracy and explainability of explainable recommendations by tightly coupling the recommendation task and the explanation task. The CAML architecture is inspired by the cognitive processes of humans and consists of three major sub-processes, modeled as the encoder, selector, and decoder in the neural network.

In an explainable recommendation, the decoder is responsible for deciding the predicted rating (recommendation task) and generating the explanations (explanation task). The selector serves as the transferred cross-knowledge for both tasks. Further, CAML consists of a hierarchical co-attentive selector to effectively control the cross-knowledge transfer for both tasks. The selector models the deep level interactions between the users and items. In particular, it identifies reviews and concepts (cross-knowledge) that are important for the user-item pair based on co-attention. The architecture of CAML is shown below.

architecture of caml

Source: Paper

Reinforcement Learning

Reinforcement Learning is a paradigm of Machine Learning that falls somewhere in between the realms of supervised and unsupervised learning. In this learning scheme, an algorithm learns by making decisions through trial and error, where it is rewarded for correct decisions and penalized for wrong ones. It is commonly used for robotics applications.

Since many Reinforcement Learning problems do not necessarily involve complex perception, such as working with words or pixels, the architectural demand is not as high for many problems. Because of this, many deep networks for Reinforcement Learning are simple fully-connected, convolutional, or recurrent architectures. However, in the multi-task case, there are several instances of interesting works that leverage information between tasks to create improved architectures for Reinforcement Learning.

One such example is the Contextual Attention-based REpresentation learning or CARE model proposed in this paper, where the authors encode an input observation into multiple representations (corresponding to different skills or objects) using a mixture of encoders. The learning agent is then allowed to use the context to decide which representation(s) it uses for any given task, giving the agent fine-grained control over what information is shared across tasks, thus alleviating the problem of negative transfer we discussed before. The architectural diagram of CARE is shown below.
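The mixture-of-encoders idea can be sketched as an attention-weighted sum of several encodings of the same observation, with the weights determined by a task context vector. Scoring by dot product, the two-dimensional vectors, and the values below are assumptions made for illustration; this is not the CARE implementation.

```python
import math

def softmax(scores):
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def mix_encodings(encodings, context):
    """Attention-weighted sum of encoder outputs, scored by dot product with `context`."""
    scores = [sum(e * c for e, c in zip(encoding, context)) for encoding in encodings]
    weights = softmax(scores)
    dim = len(encodings[0])
    mixed = [sum(w * enc[i] for w, enc in zip(weights, encodings)) for i in range(dim)]
    return mixed, weights

# Outputs of three hypothetical encoders for the same observation.
encodings = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
mixed, weights = mix_encodings(encodings, context=[2.0, 0.0])
```

A different task context would put its weight on different encoders, which is how the agent controls what is shared across tasks.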

Source: Paper

Multimodal Learning

Multimodal Learning, as the name suggests, involves training models on multiple modalities of data (like audio, images, video, natural text, etc.), which may or may not be correlated. Multi-Task Learning is popularly used for implicitly injecting multimodal features into a single model.

An example of such an architecture is the model proposed in this paper, where the authors tackle the problem of automatic video captioning (the task of describing the content of a video using natural text). Video captioning models still suffer from the lack of sufficient temporal and logical supervision to be able to correctly capture the action sequence and story-dynamic language in videos, especially in the case of short clips.

The authors argue that video captioning benefits from incorporating complementary directed knowledge, both visual and textual. They address this problem by jointly training the task of video captioning with two related directed-generation tasks: a temporally-directed unsupervised video prediction task and a logically-directed language entailment generation task.

The unsupervised video prediction task, i.e., video-to-video generation, shares its encoder with the video captioning task’s encoder and helps it learn richer video representations that can predict their temporal context and action sequence. The entailment generation task, i.e., premise-to-entailment generation based on the image caption domain, shares its decoder with the video captioning decoder and helps it learn better video-entailed caption representations since the caption is essentially an entailment of the video, i.e., it describes subsets of objects and events that are logically implied by (or follow from) the entire video content. The architecture of the multi-task model proposed by the authors is shown below.

Source: Paper

Key Takeaways

Using separate models to tackle different tasks in Artificial Intelligence is resource-inefficient. Much like humans are aided in performing different tasks by the context information available from other related tasks, Multi-Task Learning aims to tailor a single model to perform various tasks through joint optimization.

However, not all tasks can be grouped together to be tackled by the same model, and different optimization techniques have been proposed in the literature for task selection and scheduling to get the optimal performance (both in terms of accuracy and resource utilization).

We have seen the wide uses of Multi-Task Learning in the different fields of AI—Computer Vision, NLP, etc. However, there are still open problems that need to be solved. For example, can Multi-Task Learning be used in a fully Self-Supervised setting and produce optimal performance? Research is still being conducted to reduce the reliance on labeled data for model training while allowing a computationally efficient model to perform optimally on several different tasks.

Rohit Kundu is a Ph.D. student in the Electrical and Computer Engineering department of the University of California, Riverside. He is a researcher in the Vision-Language domain of AI and has published several papers in top-tier conferences and notable peer-reviewed journals.
