The way we train machine learning models has come a long way in recent years.
A conventional approach was to gather all data at a central server and use it to train the model. But this method, while easy, has raised concerns about data privacy, leaving a lot of valuable but sensitive data inaccessible.
To address this issue, AI models started to shift to a decentralized approach, and a new concept called "federated learning" has emerged.
In this article, we’ll explore federated learning from the ground up, including its most common applications in machine learning.
Federated learning (often referred to as collaborative learning) is a decentralized approach to training machine learning models. It doesn’t require an exchange of data from client devices to global servers. Instead, the raw data on edge devices is used to train the model locally, increasing data privacy. The final model is formed in a shared manner by aggregating the local updates.
Here’s why federated learning is important.
A generic baseline model is stored at the central server. The copies of this model are shared with the client devices, which then train the models based on the local data they generate. Over time, the models on individual devices become personalized and provide a better user experience.
In the next stage, the updates (model parameters) from the locally trained models are sent to the main model on the central server using secure aggregation techniques. The server combines and averages these updates to produce an improved global model. Since the updates come from diverse data sources, the global model has greater scope to generalize.
Once the central model has been updated with the new parameters, it’s shared with the client devices again for the next iteration. With every cycle, the models learn from more varied data and improve further, while the raw data never leaves the devices.
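The secure aggregation step mentioned above can be illustrated with pairwise masking, one common idea behind such protocols. This is a minimal sketch, not a production protocol: the function names are hypothetical, and a single seeded random generator stands in for the pairwise key agreement and dropout handling that real systems need.

```python
import random

def pairwise_masks(num_clients, dim, seed=0):
    """Build one mask vector per client such that all masks sum to zero.

    For each pair (i, j) with i < j, a shared random vector is added to
    client i's mask and subtracted from client j's mask.
    """
    rng = random.Random(seed)  # assumption: stands in for pairwise shared secrets
    masks = [[0.0] * dim for _ in range(num_clients)]
    for i in range(num_clients):
        for j in range(i + 1, num_clients):
            shared = [rng.uniform(-1.0, 1.0) for _ in range(dim)]
            for d in range(dim):
                masks[i][d] += shared[d]
                masks[j][d] -= shared[d]
    return masks

def secure_sum(updates, masks):
    """Sum the masked updates on the server; the pairwise masks cancel out."""
    masked = [[u + m for u, m in zip(update, mask)]
              for update, mask in zip(updates, masks)]
    dim = len(updates[0])
    return [sum(vec[d] for vec in masked) for d in range(dim)]

# Toy local updates from three clients (made-up numbers).
updates = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6]]
masks = pairwise_masks(num_clients=3, dim=2)
total = secure_sum(updates, masks)
average = [t / len(updates) for t in total]
```

The server only ever sees each client's update plus its mask, yet `average` matches the plain average of the unmasked updates because the masks cancel in the sum.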
Let’s explore the most common strategies and algorithms for federated learning.
Centralized federated learning requires a central server. It coordinates the selection of client devices in the beginning and gathers the model updates during training. The communication happens only between the central server and individual edge devices.
While this approach looks straightforward and generates accurate models, the central server is a bottleneck: a network failure there can halt the entire process.
Decentralized federated learning does not require a central server to coordinate the learning. Instead, the model updates are shared only among the interconnected edge devices. The final model is obtained on an edge device by aggregating the local updates of the connected edge devices.
This approach removes the single point of failure; however, the model's accuracy depends heavily on the network topology of the edge devices.
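To make the dependence on topology concrete, here is a minimal sketch of gossip-style averaging, in which each device repeatedly averages its parameter with those of its neighbors. The scalar "model", the two topologies, and all function names are illustrative assumptions rather than a specific published protocol.

```python
def gossip_round(values, neighbors):
    """One synchronous round: each node averages itself with its neighbors."""
    new_values = []
    for node, value in enumerate(values):
        group = [value] + [values[n] for n in neighbors[node]]
        new_values.append(sum(group) / len(group))
    return new_values

def ring(num_nodes):
    """Ring topology: each node talks only to its left and right neighbor."""
    return {i: [(i - 1) % num_nodes, (i + 1) % num_nodes]
            for i in range(num_nodes)}

def fully_connected(num_nodes):
    """Every node talks to every other node."""
    return {i: [j for j in range(num_nodes) if j != i]
            for i in range(num_nodes)}

def spread(values):
    """Disagreement between devices: largest minus smallest parameter."""
    return max(values) - min(values)

# One scalar parameter per device (made-up numbers).
local_params = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
ring_vals = list(local_params)
full_vals = list(local_params)
for _ in range(3):
    ring_vals = gossip_round(ring_vals, ring(8))
    full_vals = gossip_round(full_vals, fully_connected(8))
```

After the same number of rounds, the fully connected devices already agree on the average, while the devices on the ring still disagree. Sparse topologies need more communication rounds to reach the same consensus.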
Heterogeneous federated learning involves having heterogeneous clients such as mobile phones, computers, or IoT (Internet of Things) devices. These devices may differ in terms of hardware, software, computation capabilities, and data types.
HeteroFL was developed in response to common federated learning strategies that assume the local models have the same architecture as the global model. In the real world, this is rarely the case. HeteroFL can produce a single global model for inference after training over multiple, varied local models.
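One way to picture this is with a single weight matrix: weaker clients train a smaller top-left block of the global matrix, and the server averages each entry over the clients that actually hold it. The sketch below follows our reading of the HeteroFL idea for one hidden layer. The function names, the shrink ratios, and the fake "training" step are assumptions for illustration only.

```python
import math

def slice_submodel(global_w, ratio):
    """Take the top-left block of the global weight matrix for a smaller client."""
    rows = math.ceil(len(global_w) * ratio)
    cols = math.ceil(len(global_w[0]) * ratio)
    return [row[:cols] for row in global_w[:rows]]

def aggregate(global_w, client_weights):
    """Average each entry over the clients whose sub-model contains it.

    Entries that no client holds keep their previous global value.
    """
    new_w = [row[:] for row in global_w]
    for r in range(len(global_w)):
        for c in range(len(global_w[0])):
            held = [w[r][c] for w in client_weights
                    if r < len(w) and c < len(w[0])]
            if held:
                new_w[r][c] = sum(held) / len(held)
    return new_w

global_w = [[0.0] * 4 for _ in range(4)]
small = slice_submodel(global_w, 0.5)  # 2x2 block for a low-capacity device
large = slice_submodel(global_w, 1.0)  # full 4x4 matrix for a capable device

# Stand-in for local training: shift every weight by a constant.
small = [[v + 1.0 for v in row] for row in small]
large = [[v + 3.0 for v in row] for row in large]

new_global = aggregate(global_w, [small, large])
```

Entries in the shared 2x2 block are averaged over both clients, while the remaining entries are updated by the larger client alone.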
In traditional SGD, the gradients are computed on mini-batches, each a small subset of the full training set. In the federated setting, each client device and its local data can be viewed as one such mini-batch.
In FedSGD, the central model is distributed to the clients, and each client computes the gradients using local data. These gradients are then passed to the central server, which aggregates the gradients in proportion to the number of samples present on each client to calculate the gradient descent step.
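The aggregation described above can be sketched on a toy problem. The example assumes a one-parameter linear model y = w * x with mean squared error; the datasets, learning rate, and function names are made up for illustration.

```python
def local_gradient(w, data):
    """Gradient of the mean squared error for the 1-D model y = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in data) / len(data)

def fedsgd_step(w, client_datasets, lr):
    """One FedSGD step: average client gradients weighted by sample count."""
    total = sum(len(d) for d in client_datasets)
    grad = sum(len(d) / total * local_gradient(w, d) for d in client_datasets)
    return w - lr * grad

# Three clients with different amounts of local data, all drawn from y = 2x.
clients = [
    [(1.0, 2.0), (2.0, 4.0)],
    [(3.0, 6.0)],
    [(0.5, 1.0), (1.5, 3.0), (2.5, 5.0)],
]

w = 0.0
for _ in range(100):
    w = fedsgd_step(w, clients, lr=0.05)
```

Weighting by sample count makes the aggregated gradient equal to the gradient computed on the pooled data, so `w` approaches 2.0 even though no client shares its raw samples.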
Federated averaging is an extension of the FedSGD algorithm.
Clients can perform more than one local gradient descent update. Instead of sharing the gradients with the central server, weights tuned on the local model are shared. Finally, the server aggregates the clients' weights (model parameters).
Federated Averaging is a generalization of FedSGD: if all the clients begin from the same initialization and take a single local gradient step with the same learning rate, averaging the gradients is equivalent to averaging the weights. Federated Averaging therefore leaves room for further local tuning of the weights before they are sent to the central server for averaging.
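The relationship between the two algorithms can be checked numerically. The sketch below again assumes a toy one-parameter model y = w * x; all names and numbers are illustrative. With one local step, the weight average coincides with the FedSGD update. With several local steps, the two diverge.

```python
def local_gradient(w, data):
    """Gradient of the mean squared error for the 1-D model y = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in data) / len(data)

def local_train(w, data, lr, local_steps):
    """Run several gradient descent steps on one client's local data."""
    for _ in range(local_steps):
        w = w - lr * local_gradient(w, data)
    return w

def fedavg_round(w, client_datasets, lr, local_steps):
    """One FedAvg round: average locally trained weights by sample count."""
    total = sum(len(d) for d in client_datasets)
    return sum(len(d) / total * local_train(w, d, lr, local_steps)
               for d in client_datasets)

def fedsgd_step(w, client_datasets, lr):
    """One FedSGD step: apply the sample-weighted average gradient."""
    total = sum(len(d) for d in client_datasets)
    grad = sum(len(d) / total * local_gradient(w, d) for d in client_datasets)
    return w - lr * grad

clients = [
    [(1.0, 2.0), (2.0, 4.0)],
    [(3.0, 6.0)],
    [(0.5, 1.0), (1.5, 3.0), (2.5, 5.0)],
]

sgd_step = fedsgd_step(0.0, clients, lr=0.05)
one_step = fedavg_round(0.0, clients, lr=0.05, local_steps=1)
five_steps = fedavg_round(0.0, clients, lr=0.05, local_steps=5)
```

Here `one_step` equals `sgd_step` up to floating-point rounding. In this toy setup `five_steps` lands closer to the true value of 2.0 after a single communication round, which is the practical appeal of doing more work locally.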
Regularization in traditional machine learning adds a penalty to the loss function to improve generalization. In federated learning, the global loss must be computed from local losses generated on heterogeneous devices.
Because the clients are heterogeneous, minimizing the global loss is different from minimizing the local losses. The FedDyn method therefore generates a regularization term for each local loss by adapting to the data statistics, such as the amount of data or communication cost. This dynamic regularization of the local losses enables them to converge to the global loss.
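As a rough illustration, the dynamically regularized local objective can be written as the local loss, minus a linear term built from the client's previous gradient, plus a quadratic pull toward the current server model. This sketch reflects our understanding of the published FedDyn formulation. The function name, argument names, and toy loss are assumptions.

```python
def feddyn_local_objective(theta, local_loss, prev_local_grad,
                           server_theta, alpha):
    """Sketch of a dynamically regularized local objective.

    local_loss(theta)                        loss on this client's data
    - <prev_local_grad, theta>               linear term from the previous round
    + alpha / 2 * ||theta - server_theta||^2 pull toward the server model
    """
    linear = sum(g * t for g, t in zip(prev_local_grad, theta))
    prox = sum((t - s) ** 2 for t, s in zip(theta, server_theta))
    return local_loss(theta) - linear + 0.5 * alpha * prox

def toy_loss(theta):
    """Hypothetical local loss with its minimum at theta = [1.0]."""
    return (theta[0] - 1.0) ** 2

# First round: no gradient history and the client starts at the server model,
# so the objective reduces to the plain local loss.
first = feddyn_local_objective([0.0], toy_loss, [0.0], [0.0], alpha=0.1)

# Later round: the linear and quadratic terms shift the objective.
later = feddyn_local_objective([1.0], toy_loss, [0.5], [0.0], alpha=0.1)
```

The linear term changes every round as the client's gradient history changes, which is what makes the regularizer dynamic rather than a fixed penalty.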
As research in computer vision progresses with large-scale Convolutional Neural Networks and dense transformer models, the scarcity of tools and techniques to implement them in the federated setting becomes evident.
The FedCV framework is built to bridge the gap between research and the real-world implementation of federated learning algorithms.
FedCV is a unified library for federated learning to address computer vision applications of image segmentation, image classification, and object detection. It provides access to various datasets and models through easy-to-use APIs. The framework consists of three major modules:
Let’s discuss the contributions of all these modules.
The high-level API consists of models for the computer vision tasks of image segmentation, image classification, and object detection. Users can rely on the existing data loaders and data partitioning schemes. They can also create their own non-i.i.d. data (data that is not independent and identically distributed) to test the robustness of federated learning methods, since real-world data is usually non-i.i.d.
The high-level API also provides implementations of state-of-the-art federated learning algorithms such as FedAvg, FedNAS, and many more. Training can be completed in a reasonable time thanks to support for distributed multi-GPU training. Additionally, the algorithms can be trained using novel distributed computing strategies.
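A common way to create non-i.i.d. client data for such robustness tests is to split each class across clients using proportions drawn from a Dirichlet distribution. The sketch below is a generic illustration and not FedCV's own partitioning API. The function name and parameters are assumptions.

```python
import random

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split sample indices across clients with label-skewed proportions.

    Smaller alpha gives more skewed (more strongly non-i.i.d.) clients.
    """
    rng = random.Random(seed)
    client_indices = [[] for _ in range(num_clients)]
    for cls in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == cls]
        rng.shuffle(idx)
        # A Dirichlet sample is a set of Gamma draws normalized by their sum.
        draws = [rng.gammavariate(alpha, 1.0) for _ in range(num_clients)]
        total = sum(draws)
        start, cumulative = 0, 0.0
        for client, draw in enumerate(draws):
            cumulative += draw / total
            # The last client takes whatever remains, so no sample is lost.
            if client == num_clients - 1:
                end = len(idx)
            else:
                end = round(cumulative * len(idx))
            client_indices[client].extend(idx[start:end])
            start = end
    return client_indices

# 300 samples over 3 balanced classes, split across 4 clients.
labels = [i % 3 for i in range(300)]
parts = dirichlet_partition(labels, num_clients=4, alpha=0.5)
class_counts = [[sum(1 for i in part if labels[i] == c) for c in range(3)]
                for part in parts]
```

Inspecting `class_counts` shows how unevenly each class is spread over the clients. Rerunning with a larger `alpha` tends to produce splits that are closer to uniform.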
The user-oriented API design enables easy implementation and flexible interactions between clients and workers.
The low-level API provides security and privacy primitives that allow secure and private communication between servers in different locations.
Like any newly developed technology, federated learning faces a few crucial challenges. Let’s go through a few examples.
Federated learning can involve millions of devices in one network. Message transfer can be slow for several reasons: low bandwidth, limited device resources, or geographical location.
To keep the communication channels efficient, both the total number of message passes and the size of each message should be reduced. We can achieve this by using
Privacy and data security are some of the biggest concerns with federated learning. Although the local data stays on the user device, there’s a risk that information can be inferred from the model updates shared in the network.
Some of the common privacy-preserving techniques that can solve this problem include:
With the large number of devices playing a role in federated learning networks, accounting for differences in storage, communication, and computational capabilities is a huge challenge. Additionally, only a few of these devices participate at a given time, which may lead to biased training.
Such heterogeneity can be handled with techniques such as asynchronous communication, active device sampling, and fault tolerance.
This problem is posed by the multiple variations of data present across the client devices.
For example, some devices may have high-resolution image data, while others can only store low-resolution pictures, or languages might vary based on geographical location.
These examples show that data in a federated learning setting is non-i.i.d., in contrast with the i.i.d. assumption made by conventional learning algorithms. This can cause problems in the data structuring, modeling, and inference phases.
Federated learning is already present across many different use cases and industries. Let’s go through a few of the most common applications.
Smartphones are one of the most common places to see federated learning in action. Word prediction, face recognition for logging in, and voice recognition in assistants such as Siri or Google Assistant are all examples of federated-learning-based solutions. Federated learning helps personalize the user experience while maintaining privacy.
Self-driving cars use computer vision and machine learning to analyze their surroundings and act on what they learn in real time. To adapt continuously to the environment, the models need to learn from diverse datasets to improve precision.
Relying on a traditional cloud-based approach would slow down the systems. Using federated learning can speed up the learning and make the models more robust.
Manufacturers usually estimate the demand for a product from their own sales data. With federated learning, product recommendation systems can be improved using a broader set of data.
AR/VR can be used to detect objects and assist with remote operations and virtual assembly. Federated learning can help improve detection systems to create optimal models.
Another example is industrial environment monitoring. Federated learning makes it easier to run a time-series analysis of environmental factors collected from multiple sensors and companies while keeping confidential data private.
The sensitive nature of healthcare data and its restricted access due to privacy issues make it difficult to scale machine learning systems in this industry globally.
With federated learning, models can be trained through secure access to data from patients and medical institutions while the data remains on its original premises. It helps individual institutions collaborate with others and makes it possible for the models to learn from more datasets securely.
Additionally, federated learning can allow clinicians to gain insights about patients or diseases from wider demographic areas beyond local institutions and grant smaller rural hospitals access to advanced AI technologies.
Federated learning (FL) is a decentralized approach to training machine learning models. Compared with the usual centralized approaches, it offers privacy protection, data security, and access to heterogeneous data.
We can obtain more accurate and generalizable models through FL without having the data leave the client devices.
The three main strategies to perform FL are Centralized FL, Decentralized FL, and Heterogeneous FL; some of the popular FL algorithms include FedSGD, FedAvg, and FedDyn.
FedCV is an FL framework built for computer vision applications. It bridges the gap between research and implementation by providing an easy-to-use unified library with multiple functionalities. Practical applications of FL can be found across multiple industries, including healthcare, transportation, and manufacturing.
Despite the ongoing research on scaling FL systems, certain limitations still need to be addressed. It’s necessary to improve communication efficiency, protect data privacy, and handle heterogeneity at both the system and statistical levels.