Edgify’s Collaborative Method for Distributed Learning, to be Fully Released at NeurIPS this Year!

## Overview

There is growing interest today in training deep learning models on the edge. Algorithms such as *Federated Averaging* [1] (FedAvg) allow training on devices with high network latency by performing many local gradient steps before communicating their weights. However, this setting by its very nature offers no control over how the data is distributed across the devices.
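To make the communication pattern concrete, here is a minimal sketch of FedAvg's outer loop in plain Python. It assumes a toy one-parameter least-squares model and full-batch gradient steps in place of a neural network trained with minibatch SGD. The function names `local_sgd` and `fedavg` are ours, not from [1].

```python
def local_sgd(w, data, lr, steps):
    """Run `steps` full-batch gradient steps on the mean squared error
    mean((w * x - y) ** 2) over one device's (x, y) pairs."""
    for _ in range(steps):
        grad = sum(2.0 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w


def fedavg(devices, w0=0.0, rounds=10, local_steps=20, lr=0.1):
    """FedAvg outer loop: broadcast the global weight, let every device train
    locally, then average the returned weights by local sample count."""
    w = w0
    total = sum(len(d) for d in devices)
    for _ in range(rounds):
        local = [local_sgd(w, d, lr, local_steps) for d in devices]
        w = sum(len(d) * w_d for d, w_d in zip(devices, local)) / total
    return w


# Two devices whose data both follow y = 3x; the global weight converges to about 3.0.
devices = [[(1.0, 3.0), (2.0, 6.0)], [(0.5, 1.5)]]
print(fedavg(devices))
```

Raising `local_steps` is what saves communication: each round does more work locally before the weighted average is taken. When the devices' data differ (non-IID), their local optima differ too, and those long local runs pull the local models toward their own data.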

Consider, for instance, a smart checkout scale at a supermarket with a mounted camera and some processing power. You want each scale to collect images of the fruits and vegetables being weighed, and the scales to collectively train a neural network that recognizes them. In such an unconstrained environment, it is almost certain that not all edge devices (here, scales) will have data from all classes (here, fruits and vegetables). This is commonly referred to as a *non-IID data distribution*.

Training with FedAvg on non-IID data leads the locally trained models to “forget” the classes for which they have little or no data.

In a recent paper [2], which will appear at the NeurIPS Workshop on Federated Learning for Data Privacy and Confidentiality, we present *Federated Curvature* (FedCurv), an algorithm for Federated Learning on non-IID data. FedCurv builds on ideas from *Lifelong Learning* to prevent knowledge forgetting in Federated Learning.

**Lifelong and Federated Learning**

In Lifelong Learning, the challenge is to learn task A and then continue on to learn task B using the same model, without “forgetting” task A, i.e. without severely hurting performance on that task. More generally, the goal is to learn tasks $A_1, A_2, \dots$ in sequence without forgetting previously learned tasks whose samples are no longer presented.

In the paper introducing *Elastic Weight Consolidation* (EWC) [3], the authors propose an algorithm for sequentially training a model on new tasks without forgetting old ones.

The idea behind EWC is to prevent forgetting by identifying the coordinates of the network parameters that are most informative for a learned task A and then, while task B is being learned, penalizing the learner for changing those parameters. The underlying assumption is that deep neural networks are sufficiently over-parameterized that there is a good chance of finding an optimal solution $\theta^*_B$ for task B in the neighborhood of the previously learned solution $\theta^*_A$. The authors depict the idea with the following diagram:

To “select” the parameters that are important for the previous task, the authors use the diagonal of the Fisher information matrix. This diagonal has the same size as the model's parameter tensor, and the value of each entry reflects the “importance” of the matching model parameter.
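For intuition on what the Fisher diagonal measures, here is a sketch for about the simplest model possible: logistic regression with $p(y=1 \mid x) = \sigma(w \cdot x)$. In this case the expectation over the model's own labels has a closed form, $p(1-p)\,x_i^2$ per coordinate. For a deep network one would instead estimate the same quantity by averaging squared gradients of the log-likelihood over samples. The function name `fisher_diagonal` is ours.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def fisher_diagonal(w, xs):
    """Diagonal of the Fisher information of p(y | x, w) = Bernoulli(sigmoid(w . x)),
    averaged over the inputs xs. The score is (y - p) * x, and its expected square
    under the model's own label distribution is p * (1 - p) * x_i**2."""
    diag = [0.0] * len(w)
    for x in xs:
        p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)))
        for i, xi in enumerate(x):
            diag[i] += p * (1.0 - p) * xi * xi
    return [d / len(xs) for d in diag]


print(fisher_diagonal([0.0, 0.0], [[1.0, 0.0], [1.0, 2.0]]))  # [0.25, 0.5]
```

The second weight gets the larger value because its input varies more across the data. A weight whose input is always zero would get Fisher information 0, so EWC would leave it free to change.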

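The authors enforce the penalty by adding a term to the optimization objective that forces model parameters with high Fisher information for task A to preserve their values while task B is learned. With $L_B(\theta)$ the loss for task B, $\theta^*_A$ the parameters learned for task A, $F_i$ the $i$-th entry of the Fisher diagonal, and $\lambda$ a hyperparameter that sets how important the old task is relative to the new one, the objective is:

$$L(\theta) = L_B(\theta) + \sum_i \frac{\lambda}{2} F_i \left(\theta_i - \theta^*_{A,i}\right)^2$$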
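The EWC penalty $\sum_i \frac{\lambda}{2} F_i (\theta_i - \theta^*_{A,i})^2$ and its gradient can be sketched in plain Python. The toy task-B loss $\sum_i (\theta_i - b_i)^2$ and all function names here are illustrative assumptions, not code from [3].

```python
def ewc_penalty(theta, theta_star_a, fisher_a, lam):
    """(lam / 2) * sum_i F_i * (theta_i - theta*_A,i) ** 2"""
    return 0.5 * lam * sum(
        f * (t - ts) ** 2 for t, ts, f in zip(theta, theta_star_a, fisher_a)
    )


def ewc_penalty_grad(theta, theta_star_a, fisher_a, lam):
    """Gradient of the penalty: lam * F_i * (theta_i - theta*_A,i)."""
    return [lam * f * (t - ts) for t, ts, f in zip(theta, theta_star_a, fisher_a)]


def train_task_b(theta_star_a, fisher_a, target_b, lam, lr=0.05, steps=500):
    """Gradient descent on the toy task-B loss sum_i (theta_i - target_b_i) ** 2
    plus the EWC penalty, starting from the task-A solution."""
    theta = list(theta_star_a)
    for _ in range(steps):
        pen = ewc_penalty_grad(theta, theta_star_a, fisher_a, lam)
        theta = [t - lr * (2.0 * (t - tb) + g)
                 for t, tb, g in zip(theta, target_b, pen)]
    return theta


# Parameter 0 is important for task A (F = 10), parameter 1 is not (F = 0).
print(train_task_b([0.0, 0.0], [10.0, 0.0], [2.0, 2.0], lam=1.0))
# approximately [0.333, 2.0]
```

Task B pulls both parameters toward 2, but only the unimportant one gets there. For this toy loss the minimizer of each coordinate is $2 b_i / (2 + \lambda F_i)$, so the parameter with $F = 10$ stops at $1/3$.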

This loss adjustment can be extended to multiple tasks by making the penalty term a sum over all previous tasks.
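Written out, with $k$ indexing the previously learned tasks, $\theta^*_k$ the parameters obtained after task $k$, $F_{k,i}$ the $i$-th Fisher diagonal entry computed for task $k$, and $L_{\text{new}}$ the loss of the task currently being learned (our notation), one way to write the extended objective is:

$$L(\theta) = L_{\text{new}}(\theta) + \sum_{k} \sum_i \frac{\lambda}{2} F_{k,i} \left(\theta_i - \theta^*_{k,i}\right)^2$$

Each previous task contributes its own quadratic term for every parameter, so written this way the penalty grows with the number of tasks.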

**Federated Curvature**

For Federated Learning, we adapt EWC from a sequential algorithm to a parallel one. In this scenario, we keep communicating and averaging the local models, just like in FedAvg, but we also add an EWC-style penalty to each device's local loss that forces its model to preserve the knowledge of all the other devices. During communication, each device sends its model and the model’s Fisher information matrix diagonal. Each device's local objective therefore combines its own loss with one penalty term per other device, built from that device's model and Fisher diagonal.
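In the notation of the EWC objective, one natural way to write the local objective of device $s$ in communication round $t$ is shown below. This is our sketch rather than the exact formulation of [2]: $L_s$ is the device's local loss, $\mathcal{S}$ the set of devices, and $\hat\theta_{t-1,j}$ and $F_{t-1,j}$ the model and Fisher diagonal that device $j$ sent in the previous round.

$$\tilde{L}_{t,s}(\theta) = L_s(\theta) + \frac{\lambda}{2} \sum_{j \in \mathcal{S} \setminus \{s\}} \sum_i F_{t-1,j,i} \left(\theta_i - \hat\theta_{t-1,j,i}\right)^2$$

Compared with the multi-task EWC objective, the outer sum runs over the other devices' most recent models instead of over earlier tasks.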

This way, we enable training on local data, without forgetting the knowledge gained from data of other devices (such as other classes).

**Keeping Low Bandwidth and Preserving Privacy**

At first glance, the number of new terms added to the loss would seem to grow linearly with the number of edge devices. However, as we show in [2], simple arithmetic manipulations reduce the penalty to a constant number of terms that depend only on sums of the devices' Fisher information matrix diagonals, so the number of terms in the loss function does not depend on the number of edge devices. This also means that while each edge device sends its model and its Fisher information matrix diagonal to the central server, the server only needs to send each device the *aggregation* of the individual models and their Fisher information matrix diagonals.

Note that FedCurv only sends the central server information that is aggregated over each device's local data (the model and its gradient-derived Fisher information). In terms of privacy, it is therefore not significantly different from the classical FedAvg algorithm.
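The arithmetic behind this can be sketched as follows. Per coordinate, $\sum_j F_j(\theta - \hat\theta_j)^2 = \theta^2 \sum_j F_j - 2\theta \sum_j F_j\hat\theta_j + \sum_j F_j\hat\theta_j^2$. The last sum does not depend on $\theta$, so the gradient only needs the two totals $U = \sum_j F_j$ and $V = \sum_j F_j \hat\theta_j$. In this illustration (helper names are hypothetical, and the exact scheme in [2] may differ), the server broadcasts $U$ and $V$ computed over all devices, and each device subtracts the contribution it sent itself, which assumes it keeps a copy of that model and Fisher diagonal.

```python
def aggregate(models, fishers):
    """Server side: per-coordinate totals U = sum_j F_j and V = sum_j F_j * theta_j."""
    n = len(models[0])
    u = [sum(f[i] for f in fishers) for i in range(n)]
    v = [sum(f[i] * m[i] for m, f in zip(models, fishers)) for i in range(n)]
    return u, v


def penalty_grad_from_totals(theta, u, v, own_model, own_fisher, lam):
    """Device side: drop this device's own contribution from the totals. The
    gradient of (lam / 2) * sum_{j != s} sum_i F_j,i * (theta_i - theta_j,i) ** 2
    is then lam * (U_s,i * theta_i - V_s,i)."""
    grad = []
    for t, ui, vi, m, f in zip(theta, u, v, own_model, own_fisher):
        u_s = ui - f
        v_s = vi - f * m
        grad.append(lam * (u_s * t - v_s))
    return grad


def penalty_grad_direct(theta, others, lam):
    """Reference: the same gradient summed term by term over the other devices."""
    grad = [0.0] * len(theta)
    for m, f in others:
        for i, t in enumerate(theta):
            grad[i] += lam * f[i] * (t - m[i])
    return grad


models = [[1.0, 2.0], [3.0, -1.0], [0.0, 4.0]]
fishers = [[0.5, 1.0], [2.0, 0.0], [1.0, 3.0]]
u, v = aggregate(models, fishers)
theta = [1.5, 0.5]
# Device 0's view: both computations give [-1.5, -10.5].
print(penalty_grad_from_totals(theta, u, v, models[0], fishers[0], lam=1.0))
print(penalty_grad_direct(theta, list(zip(models[1:], fishers[1:])), lam=1.0))
```

However many devices participate, the server's broadcast is two vectors of the model's size. The dropped constant term only shifts the loss value, not its gradient.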

**Experiments**

We tested FedCurv on a set of 96 edge devices. We used MNIST for the experiment and divided the data so that every device holds images from exactly 2 classes (images that no other device sees). We compared FedCurv to FedAvg and to FedProx [4] (the leading existing alternative, whose description is beyond the scope of this blog).
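One simple way to produce a split like this is sketched below. It is a hypothetical scheme, not necessarily the one used in [2]: device $d$ is assigned classes $2d$ and $2d+1$ (mod the number of classes), which always differ when there are at least two classes, and each class's images are dealt out disjointly to the devices holding it. The function name `two_class_split` is ours.

```python
from collections import defaultdict


def two_class_split(labels, num_devices, num_classes):
    """Hypothetical non-IID split: device d gets classes 2d and 2d+1 (mod
    num_classes); each class's sample indices are divided disjointly among the
    devices holding that class. Assumes num_classes >= 2."""
    # Which devices hold each class.
    holders = defaultdict(list)
    for d in range(num_devices):
        for c in ((2 * d) % num_classes, (2 * d + 1) % num_classes):
            holders[c].append(d)
    # Sample indices grouped by class label.
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    # Deal each class's samples round-robin to its holders, so no index is shared.
    parts = [[] for _ in range(num_devices)]
    for c, idxs in by_class.items():
        hs = holders[c]
        if not hs:  # class assigned to no device (only possible with few devices)
            continue
        for k, idx in enumerate(idxs):
            parts[hs[k % len(hs)]].append(idx)
    return parts
```

With 96 devices and 10 classes, each class ends up on 19 or 20 devices, and each device's share of a class is small. This is why local training alone cannot recover the other 8 classes.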

Since the main benefit of our algorithm is that it allows less frequent communication, we expect that as the number of local epochs *E* between consecutive communication rounds increases, the advantages of FedCurv will become more apparent, i.e. FedCurv will need fewer iterations to reach a desired accuracy.

The results, presented in Table 1, show that with 50 local epochs FedCurv reached 90% accuracy three times faster than FedAvg. Figures 1 and 2 show that both FedProx and FedCurv do well at the beginning of the training process. However, FedCurv remains flexible enough to reach high accuracy by the end of the process, whereas the stiffness that FedProx imposes on the parameters comes at the expense of accuracy.

**Conclusion**

We presented the problem of non-IID data in Federated Learning, showed how it relates to the problem of forgetting in Lifelong Learning, and presented FedCurv, a novel approach for training in this setting. We also showed that FedCurv can be implemented efficiently without a substantial increase in bandwidth.