A Step-by-Step Guide to Federated Learning in Computer Vision
Feb 3, 2023
In this article, we’ll explore federated learning from the ground up, including its most common applications in machine learning.
Yesha Shastri
The way we train machine learning models has come a long way in recent years.
A conventional approach was to gather all data at a central server and use it to train the model. But this method, while easy, has raised concerns about data privacy, leaving a lot of valuable but sensitive data inaccessible.
To address this issue, AI models started to shift to a decentralized approach, and a new concept called "federated learning" has emerged.
Here’s what we’ll cover:
What is federated learning?
How does federated learning work?
Types of federated learning
Federated learning frameworks
Challenges and limitations of federated learning
Real-life applications of federated learning
Let’s go!
What is federated learning?
Federated learning (often referred to as collaborative learning) is a decentralized approach to training machine learning models. It doesn’t require an exchange of data from client devices to global servers. Instead, the raw data on edge devices is used to train the model locally, increasing data privacy. The final model is formed in a shared manner by aggregating the local updates.
Here’s why federated learning is important.
Privacy: In contrast to traditional methods, where data is sent to a central server for training, federated learning trains locally on the edge device, so the raw data never has to leave it. This reduces exposure to data breaches.
Data security: Only model updates, typically encrypted, are shared with the central server. Secure aggregation protocols additionally let the server decrypt only the aggregated result, not any individual update.
Access to heterogeneous data: Federated learning opens up data spread across multiple devices, locations, and organizations. It makes it possible to train models on sensitive data, such as financial or healthcare data, while maintaining security and privacy. And thanks to greater data diversity, models can be made more generalizable.
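To make the secure aggregation idea more concrete, here is a minimal sketch of pairwise masking, assuming integer-encoded updates. The function name `masked_updates` and the single shared random generator are illustrative assumptions; real protocols derive each mask from a key agreement between the two clients and also handle clients that drop out.

```python
import random

def masked_updates(updates, seed=0):
    """Add pairwise masks that cancel in the sum: client i adds m, client j subtracts m."""
    rng = random.Random(seed)  # assumption: stands in for per-pair shared secrets
    masked = list(updates)
    n = len(updates)
    for i in range(n):
        for j in range(i + 1, n):
            m = rng.randrange(-10**6, 10**6)  # mask known only to clients i and j
            masked[i] += m
            masked[j] -= m
    return masked

updates = [12, -7, 30]             # hypothetical integer-encoded client updates
masked = masked_updates(updates)   # what the server actually receives
print(sum(masked) == sum(updates)) # the masks cancel, so the total is unchanged
```

The server sees only the masked values, which look random on their own, yet their sum equals the sum of the true updates.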
How does federated learning work?
A generic baseline model is stored at the central server. The copies of this model are shared with the client devices, which then train the models based on the local data they generate. Over time, the models on individual devices become personalized and provide a better user experience.
In the next stage, the updates (model parameters) from the locally trained models are sent to the main model located at the central server using secure aggregation techniques. The server combines these updates, typically by averaging them, to produce an improved global model. Since the data comes from diverse sources, there is greater scope for the model to become generalizable.
Once the central model has been updated with the aggregated parameters, it’s shared with the client devices again for the next iteration. With every cycle, the model incorporates information from more varied data and improves further, while the raw data stays on the devices.
How does federated learning work? (source)
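The cycle described above can be sketched in a few lines. This is a toy simulation under stated assumptions, not a production setup: the model is a single weight `w` fitting y ≈ w · x, and the helper names (`make_client`, `local_train`, `federated_round`) are hypothetical.

```python
import random

def make_client(rng, n=20):
    """Private toy data for one client: samples of y = 3x plus a little noise."""
    xs = [rng.uniform(-1, 1) for _ in range(n)]
    return [(x, 3 * x + rng.gauss(0, 0.1)) for x in xs]

def local_train(w, data, lr=0.5, epochs=5):
    """Train on one client's data with plain gradient descent on squared error."""
    for _ in range(epochs):
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w

def federated_round(global_w, client_datasets):
    """One cycle: broadcast the model, train locally, average the returned weights."""
    local_ws = [local_train(global_w, data) for data in client_datasets]
    sizes = [len(data) for data in client_datasets]
    return sum(w * n for w, n in zip(local_ws, sizes)) / sum(sizes)

rng = random.Random(0)
clients = [make_client(rng) for _ in range(3)]  # raw data never leaves these lists

w = 0.0  # generic baseline model held by the server
for _ in range(30):
    w = federated_round(w, clients)
print(round(w, 2))  # should land close to the true slope of 3
```

Only the weight `w` travels between server and clients; the `(x, y)` pairs are read only inside `local_train`.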
Types of federated learning
Let’s explore the most common strategies and algorithms for federated learning.
Federated learning strategies
Centralized federated learning
Centralized federated learning requires a central server. It coordinates the selection of client devices in the beginning and gathers the model updates during training. The communication happens only between the central server and individual edge devices.
While this approach looks straightforward and generates accurate models, the central server is both a bottleneck and a single point of failure: if it or its network connection fails, the whole process halts.
Centralized federated learning. The updates from multiple users (B) are passed to a global model, and the cycle repeats (source)
Decentralized federated learning
Decentralized federated learning does not require a central server to coordinate the learning. Instead, the model updates are shared only among the interconnected edge devices. The final model is obtained on an edge device by aggregating the local updates of the connected edge devices.
This approach removes the single point of failure; however, the model's accuracy depends heavily on the network topology of the edge devices.
Decentralized federated learning (source)
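One simple way to picture decentralized aggregation is gossip averaging, where each device repeatedly averages its parameters with its direct neighbors. The sketch below assumes a hypothetical ring of five devices and a single scalar parameter per device; it is an illustration of the idea, not a specific published protocol.

```python
def gossip_step(values, neighbors):
    """Each node replaces its value with the mean of itself and its neighbors."""
    return [
        (values[i] + sum(values[j] for j in neighbors[i])) / (1 + len(neighbors[i]))
        for i in range(len(values))
    ]

# Hypothetical ring topology: each device talks only to its left and right peer
n = 5
ring = {i: [(i - 1) % n, (i + 1) % n] for i in range(n)}

params = [1.0, 2.0, 3.0, 4.0, 10.0]  # one scalar "model" per device
target = sum(params) / n             # what a central server would have computed

for _ in range(50):
    params = gossip_step(params, ring)
print(params, target)
```

On this symmetric ring every device ends up near the global average without any server. A sparser or asymmetric topology would converge more slowly or to a different point, which is why the topology matters so much.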
Heterogeneous federated learning
Heterogeneous federated learning involves having heterogeneous clients such as mobile phones, computers, or IoT (Internet of Things) devices. These devices may differ in terms of hardware, software, computation capabilities, and data types.
HeteroFL was developed in response to common federated learning strategies that assume every local model has the same attributes as the main model. In the real world, that is rarely the case. HeteroFL trains local models of varying complexity and still produces a single global model for inference.
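A rough sketch of how local models of different sizes might be merged into one global model is shown below. It assumes, as a simplification, that each smaller local model is a prefix of the global parameter vector; the function name `aggregate_submodels` is hypothetical and this is not the reference HeteroFL implementation.

```python
def aggregate_submodels(global_params, client_params):
    """Average each coordinate over the clients whose sub-model contains it."""
    new_params = list(global_params)
    for i in range(len(global_params)):
        holders = [p[i] for p in client_params if i < len(p)]
        if holders:  # coordinates no client trained keep their old value
            new_params[i] = sum(holders) / len(holders)
    return new_params

global_params = [0.0, 0.0, 0.0, 0.0]
client_params = [
    [1.0, 2.0, 3.0, 4.0],  # full-size model on a capable device
    [3.0, 4.0],            # half-size model on a weaker device
    [5.0],                 # quarter-size model on a very constrained device
]
result = aggregate_submodels(global_params, client_params)
print(result)  # [3.0, 3.0, 3.0, 4.0]
```

Every device contributes to the parameters it can afford to train, and the server still ends up with a single full-size model.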
Federated learning algorithms
Federated stochastic gradient descent (FedSGD)
In traditional SGD, the gradients are computed on mini-batches, each of which is a small subset of the full set of data samples. In the federated setting, each client device and its local data can be viewed as one such mini-batch.
In FedSGD, the central model is distributed to the clients, and each client computes the gradients using local data. These gradients are then passed to the central server, which aggregates the gradients in proportion to the number of samples present on each client to calculate the gradient descent step.
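The server step described above can be sketched as follows, assuming a one-parameter model y ≈ w · x with squared error. The function names and the toy data are illustrative.

```python
def local_gradient(w, data):
    """Gradient of mean squared error for y ≈ w * x on one client's data."""
    return sum(2 * (w * x - y) * x for x, y in data) / len(data)

def fedsgd_step(w, client_datasets, lr=0.1):
    """Server step: sample-weighted average of client gradients, then descend."""
    total = sum(len(d) for d in client_datasets)
    grad = sum(len(d) * local_gradient(w, d) for d in client_datasets) / total
    return w - lr * grad

clients = [
    [(1.0, 2.0), (2.0, 4.0)],                          # client with 2 samples
    [(1.0, 2.0), (3.0, 6.0), (0.5, 1.0), (2.0, 4.0)],  # client with 4 samples
]

w = 0.0
for _ in range(100):
    w = fedsgd_step(w, clients)
print(round(w, 4))  # the data follows y = 2x, so w approaches 2
```

Weighting by sample count makes the aggregated gradient identical to the gradient computed on the pooled data, even though the data is never pooled.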
Federated averaging (FedAvg)
Federated averaging is an extension of the FedSGD algorithm.
Clients can perform more than one local gradient descent update. Instead of sharing gradients with the central server, each client shares the weights it has tuned locally. Finally, the server aggregates the clients' weights (model parameters).
Federated Averaging is a generalization of FedSGD: if all clients start from the same weights, use the same learning rate, and take exactly one local gradient step, averaging the resulting weights is equivalent to averaging the gradients and taking one step on the server. What FedAvg adds is the freedom to run several local steps before the weights are sent to the central server for averaging, which can reduce the number of communication rounds.
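The relationship between the two algorithms can be checked numerically. The sketch below assumes the same toy one-parameter model y ≈ w · x; all names and data are illustrative.

```python
def grad(w, data):
    """Gradient of mean squared error for y ≈ w * x."""
    return sum(2 * (w * x - y) * x for x, y in data) / len(data)

def fedavg_round(w, clients, lr, local_steps):
    """Each client runs local_steps of gradient descent; server averages the weights."""
    total = sum(len(d) for d in clients)
    new_w = 0.0
    for data in clients:
        local_w = w
        for _ in range(local_steps):
            local_w -= lr * grad(local_w, data)
        new_w += len(data) / total * local_w
    return new_w

def fedsgd_round(w, clients, lr):
    """Server averages the gradients and takes a single step."""
    total = sum(len(d) for d in clients)
    g = sum(len(d) / total * grad(w, d) for d in clients)
    return w - lr * g

clients = [[(1.0, 2.0), (2.0, 4.1)], [(3.0, 5.9), (0.5, 1.2), (1.5, 3.0)]]

one_step = fedavg_round(0.0, clients, lr=0.05, local_steps=1)
sgd_step = fedsgd_round(0.0, clients, lr=0.05)
five_steps = fedavg_round(0.0, clients, lr=0.05, local_steps=5)
print(one_step, sgd_step, five_steps)
```

With one local step the two rounds coincide up to floating-point error. With five local steps FedAvg moves further per communication round, at the cost of each client drifting toward its own local optimum.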
Federated learning with dynamic regularization (FedDyn)
Regularization in traditional machine learning adds a penalty to the loss function to improve generalization. In federated learning, the global loss is computed from the local losses of heterogeneous devices.
Because of this client heterogeneity, minimizing each local loss is not the same as minimizing the global loss. FedDyn therefore adds a regularization term to each client’s local loss and updates that term every round, using the current global model and the client’s own previous state. This dynamic regularization steers the local solutions so that they become consistent with a minimum of the global loss.
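As a rough sketch, the local objective that client k minimizes in round t is commonly written as below. The notation is introduced here for illustration and does not come from this article: L_k is the client's local loss, θ_k^{t-1} its model from the previous round, θ^{t-1} the current server model, and α a regularization strength.

```latex
\theta_k^{t} = \arg\min_{\theta} \; L_k(\theta)
  - \left\langle \nabla L_k\!\left(\theta_k^{t-1}\right), \theta \right\rangle
  + \frac{\alpha}{2} \left\lVert \theta - \theta^{t-1} \right\rVert^{2}
```

The linear term changes every round because it depends on the client's previous gradient, which is what makes the regularization dynamic. The quadratic term keeps the local model close to the server model.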
Federated learning frameworks
As research in computer vision progresses with large-scale Convolutional Neural Networks and dense transformer models, the scarcity of tools and techniques to implement them in the federated setting becomes evident.
The FedCV framework is built to bridge the gap between research and the real-world implementation of federated learning algorithms.
FedCV is a unified library for federated learning to address computer vision applications of image segmentation, image classification, and object detection. It provides access to various datasets and models through easy-to-use APIs. The framework consists of three major modules:
Computer Vision Applications layer
High-level API
Low-level API
Let’s discuss the contributions of all these modules.
FedCV framework (source)
The high-level API consists of models for the computer vision tasks of image segmentation, image classification, and object detection. Users can rely on the existing data loaders and data partitioning schemes. They can also create their own non-i.i.d. data (data that is not independent and identically distributed), which helps test the robustness of federated learning methods, since real-world data is usually non-i.i.d.
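To illustrate what a non-i.i.d. partition looks like, here is a generic label-shard split in plain Python. It is not the FedCV API; the function name `partition_by_label_shards` and the toy labels are assumptions made for this sketch.

```python
import random
from collections import Counter

def partition_by_label_shards(labels, num_clients, shards_per_client, seed=0):
    """Sort sample indices by label, cut them into shards, deal shards to clients."""
    rng = random.Random(seed)
    order = sorted(range(len(labels)), key=lambda i: labels[i])
    num_shards = num_clients * shards_per_client
    shard_size = len(order) // num_shards
    shards = [order[s * shard_size:(s + 1) * shard_size] for s in range(num_shards)]
    rng.shuffle(shards)
    return [
        [i for shard in shards[c * shards_per_client:(c + 1) * shards_per_client]
         for i in shard]
        for c in range(num_clients)
    ]

labels = [i % 10 for i in range(1000)]  # toy dataset with 10 balanced classes
clients = partition_by_label_shards(labels, num_clients=5, shards_per_client=2)

for c, idx in enumerate(clients):
    print(c, sorted(Counter(labels[i] for i in idx)))  # classes seen by client c
```

Each client ends up with only two of the ten classes, a skew that a uniformly random split would not produce.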
The high-level API also provides implementations of state-of-the-art federated learning algorithms such as FedAvg, FedNAS, and many more. The training can be completed in a reasonable time due to the available support of distributed multi-GPU training. Additionally, the algorithms can be trained using novel distributed computing strategies.
The user-oriented API design enables easy implementation and flexible interactions between clients and workers.
The low-level API consists of enhanced security and privacy primitive modules that allow secure and private communication between servers present at different locations.
Challenges and limitations of federated learning
Like any newly developed technology, federated learning faces a few crucial challenges. Let’s go through a few examples.
Communication efficiency
A federated learning network can involve millions of devices. Message transfer becomes slow for several reasons: low bandwidth, limited device resources, or geographical location.
To keep the communication channels efficient, both the total number of message passes and the size of the message in a single pass should be reduced. We can achieve this by using
local updating methods (to reduce the number of rounds)
model compression schemes (to reduce the size of the message)
decentralized training (to operate in low bandwidth)
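As one example of a compression scheme, the sketch below keeps only the largest entries of an update (top-k sparsification) and sends them as index-value pairs. The function names are illustrative, and real systems usually combine this with quantization or error feedback.

```python
def top_k_sparsify(update, k):
    """Keep the k largest-magnitude entries as (index, value) pairs."""
    keep = sorted(range(len(update)), key=lambda i: abs(update[i]), reverse=True)[:k]
    return [(i, update[i]) for i in sorted(keep)]

def densify(pairs, size):
    """Rebuild a dense vector on the server; dropped entries become zero."""
    dense = [0.0] * size
    for i, v in pairs:
        dense[i] = v
    return dense

update = [0.01, -0.9, 0.05, 0.4, -0.02, 0.0]
message = top_k_sparsify(update, k=2)
restored = densify(message, len(update))
print(message)   # [(1, -0.9), (3, 0.4)]
print(restored)  # [0.0, -0.9, 0.0, 0.4, 0.0, 0.0]
```

The message shrinks from six values to two pairs, and the entries that carry most of the update survive.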
Privacy and data protection
Privacy and data security are some of the biggest concerns with federated learning. Although the local data stays on the user device, there’s a risk for the information to be revealed from the model updates shared in the network.
Some of the common privacy-preserving techniques that can solve this problem include:
Differential privacy: adding calibrated noise to the data or the model updates, which makes it difficult to discern any individual's information if the updates leak
Homomorphic encryption: performing computation directly on encrypted data
Secure multiparty computation: letting several data owners jointly compute a result over their combined inputs without any of them revealing its own input, which reduces the risk of a privacy breach
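For the differential privacy item, a common pattern is to clip each update to a maximum norm and then add Gaussian noise before sending it. The sketch below shows only the mechanics; the values of `max_norm` and `noise_std` are arbitrary assumptions and are not calibrated to any formal privacy guarantee.

```python
import math
import random

def clip_update(update, max_norm):
    """Scale the update down so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(v * v for v in update))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [v * scale for v in update]

def privatize(update, max_norm, noise_std, rng):
    """Clip the update, then add independent Gaussian noise to every entry."""
    clipped = clip_update(update, max_norm)
    return [v + rng.gauss(0.0, noise_std) for v in clipped]

rng = random.Random(42)
update = [3.0, 4.0]                          # L2 norm is 5
clipped = clip_update(update, max_norm=1.0)  # rescaled to norm 1
noisy = privatize(update, max_norm=1.0, noise_std=0.1, rng=rng)
print(clipped, noisy)
```

Clipping bounds how much any single client can influence the aggregate, and the noise hides what remains of its individual contribution.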
Systems heterogeneity
With the large number of devices playing a role in federated learning networks, accounting for differences in storage, communication, and computational capabilities is a huge challenge. Additionally, only a few of these devices participate at a given time, which may lead to biased training.
Techniques such as asynchronous communication, active device sampling, and fault tolerance help handle this heterogeneity.
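A minimal sketch of device sampling combined with a simple fault-tolerance rule is shown below. The client records, the `latency` field, and the deadline are hypothetical values chosen for illustration.

```python
import random

def run_round(global_w, clients, sample_size, deadline, rng):
    """Sample clients, drop those slower than the deadline, average the rest."""
    selected = rng.sample(sorted(clients), sample_size)
    arrived = [c for c in selected if clients[c]["latency"] <= deadline]
    if not arrived:  # fault tolerance: keep the old model if nobody reports in time
        return global_w, arrived
    total = sum(clients[c]["n"] for c in arrived)
    new_w = sum(clients[c]["n"] * clients[c]["w"] for c in arrived) / total
    return new_w, arrived

clients = {
    "phone":  {"w": 1.0, "n": 10, "latency": 2.0},
    "laptop": {"w": 2.0, "n": 30, "latency": 0.5},
    "sensor": {"w": 9.0, "n": 5,  "latency": 30.0},  # straggler
}

rng = random.Random(1)
w, used = run_round(0.0, clients, sample_size=3, deadline=5.0, rng=rng)
print(w, sorted(used))  # 1.75 ['laptop', 'phone']
```

The round completes without waiting for the slow sensor. The trade-off is the bias mentioned above: devices that are always slow never influence the model.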
Statistical heterogeneity
This problem arises because the data varies from one client device to another.
For example, some devices may hold high-resolution images while others can only store low-resolution pictures, or the language of the data might vary with geographical location.
Such cases mean that the data in a federated learning setting is non-i.i.d., in contrast with the i.i.d. assumption behind most conventional training algorithms. This can cause problems in the data structuring, modeling, and inference phases.
Real-life applications of federated learning
Federated learning is already present across many different use cases and industries. Let’s go through a few of the most common applications.
Smartphones
Smartphones are one of the most common places to see federated learning in action. Word prediction, face recognition for logging in, and voice recognition in assistants such as Siri or Google Assistant are all areas where federated learning has been applied or proposed. It helps personalize the user experience while maintaining privacy.
Transportation
Self-driving cars use computer vision and machine learning to analyze their surroundings and act on what they perceive in real time. To adapt continuously to the environment, the models need to learn from diverse datasets to improve precision.
Relying on a traditional cloud-based approach, where every vehicle uploads its raw data, would slow the system down. Federated learning can speed up learning and make the models more robust.
Road object detection for traffic flow analysis in V7
Manufacturing
Manufacturers usually estimate the demand for a product from their own sales data. With federated learning, product recommendation systems can be improved using a broader set of data.
AR/VR can be used to detect objects and assist with remote operations and virtual assembly. Federated learning can help improve the underlying detection models.
Another example is industrial environment monitoring. Federated learning makes it easier to run time-series analysis on environmental factors measured by sensors across multiple companies while keeping confidential data private.
Healthcare
The sensitive nature of healthcare data and its restricted access due to privacy issues make it difficult to scale machine learning systems in this industry globally.
With federated learning, models can be trained on data from patients and medical institutions while the data remains on its original premises. This helps individual institutions collaborate with others and lets the models learn securely from more datasets. It is particularly important for AI document processing in healthcare, where patient records must stay within secure hospital networks. Federated learning lets hospitals collaborate on AI models without moving sensitive documents off-premises.
Additionally, federated learning can allow clinicians to gain insights about patients or diseases from wider demographic areas beyond local institutions and grant smaller rural hospitals access to advanced AI technologies.
Federated learning in healthcare (source)
Federated Learning: Key takeaways
Federated learning (FL) is a decentralized approach to training machine learning models. Compared with the usual centralized approaches, it offers privacy protection, data security, and access to heterogeneous data.
We can obtain more accurate and generalizable models through FL without having the data leave the client devices.
The three main strategies to perform FL are Centralized FL, Decentralized FL, and Heterogeneous FL; some of the popular FL algorithms include FedSGD, FedAvg, and FedDyn.
FedCV is an FL framework built for computer vision applications. It bridges the gap between research and implementation by providing an easy-to-use unified library with multiple functionalities. Practical applications of federated learning can be found across multiple industries, including healthcare, transportation, and manufacturing.
Despite the ongoing research on scaling FL systems, certain limitations still need to be addressed. It’s necessary to improve communication efficiency, protect data privacy, and incorporate the heterogeneity present at systems and statistical levels.
Yesha is a master’s student at Université de Montréal pursuing Computer Science with Machine Learning. She is passionate about exploring AI techniques that benefit society and sharing her learnings with the AI community.