In this lesson, you will build your first federated learning project with Flower and PyTorch. You will learn how federated learning enables you to train AI on distributed data. The model and data you are going to use are just an example to showcase federated learning in action. But remember, this can be extended to most other models and datasets, and even to different frameworks like TensorFlow, JAX, Hugging Face Transformers, and Apple's MLX. All right, let's build something.

In a basic federated learning system, you have a server and you have clients. The server often does not have any data itself. It can have some data used to evaluate the global model, but in vanilla federated learning it doesn't have any training data. The clients are the ones that have the actual training data. If you have a system where five hospitals collaborate on model training, you would have five clients, one for every hospital. Each of those clients would run in one of the hospital environments, and it would have access to the data of that particular hospital. If you had a system where 100 million user devices hold data, you would have 100 million clients, one client on each of the user devices.

The role of the server is to coordinate the training across those clients. The role of each client is to do the actual training on its respective local dataset. Both the server and the clients have their own copies of the model. The model on the server is often called the global model. The models on the clients are often called the local models.

So how does the training work across multiple clients? The whole process starts with the server initializing the global model parameters. The server sends the parameters of the global model to the clients. In the example here, we have five clients: a tablet, a desktop, a mobile phone, a laptop, and a server. Those five clients then train the model on their local data. They train the model only for a little while, so not until full convergence.
Often they train for just a single epoch on their local dataset. After the local training, the clients send their improved models back to the server. The server now has five improved models, all of them with slightly different weights. But what you want is one model, not five. To get to one model, the server aggregates those five models. There are different ways to aggregate models, but one of the most common ones is to just average the weights. After the first aggregation, you end up with a slightly improved version of the global model. With this new version of the global model, you repeat the steps described before: you send the new model to the clients, the clients train on their local training data, they send back their improved models, and the server aggregates those models. Federated learning is an iterative process. It repeats those so-called rounds over and over again until convergence.

Here's a slightly more formal description of the federated learning algorithm. Step one is initialization: the server initializes the global model. Step two is the communication round: for each communication round, the server sends the global model to the participating clients, and each client receives the global model. Step three is client training and model update: each participating client trains the received model on its local dataset and, after that finishes, sends its locally updated model back to the server. Step four is model aggregation: the server aggregates the updated models received from all clients using an aggregation algorithm. For instance, this can be federated averaging, which just takes a weighted average over all the model updates received from the clients, weighted by the number of training examples that went into training on each particular client. Step five is the convergence check: if the convergence criteria are met, we end the federated learning process; if not, we proceed to the next communication round, which was described in step two.
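The five steps above can be sketched as a minimal NumPy loop. This is a toy sketch, not the Flower API: the "model" is just a list of weight arrays, and local_update is a hypothetical stand-in for real client-side training.

```python
import numpy as np

def federated_averaging(global_weights, client_data_sizes, local_update, num_rounds=3):
    """Toy sketch of the federated averaging loop.

    global_weights: list of np.ndarray (the global model)
    client_data_sizes: number of training examples held by each client
    local_update: hypothetical function (weights, client_id) -> updated weights,
                  standing in for local training on a client
    """
    for _ in range(num_rounds):
        # Steps 2 and 3: send the global model to each client;
        # each client trains locally and returns updated weights.
        client_weights = [
            local_update([w.copy() for w in global_weights], cid)
            for cid in range(len(client_data_sizes))
        ]
        # Step 4: aggregate with a weighted average (federated averaging),
        # weighted by the number of local training examples.
        total = sum(client_data_sizes)
        global_weights = [
            sum(cw[i] * (n / total) for cw, n in zip(client_weights, client_data_sizes))
            for i in range(len(global_weights))
        ]
        # Step 5: a real system would check a convergence criterion here.
    return global_weights
```

For example, with two clients holding 1 and 3 training examples, the aggregated weights are pulled toward the update of the client with more data.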
You're now going to build your first federated learning project. In the previous lesson, you built three independent datasets and trained three independent models on them. In this lesson, you're going to connect those individual models with the goal of training one collaborative model across three distributed datasets. Let's jump into the lab.

As before, you will start with some imports. In this lesson, you will use the Flower federated learning framework to federate the previously used training pipeline. So apart from utils, you also have some Flower-related imports. For example, you have ClientApp, you have ServerApp, and you have FedAvg, which is a federated averaging strategy.

The first step is for you to recreate the three datasets used in the previous lesson for training on the MNIST dataset. The MNIST training dataset is loaded with the same transformations as before, to normalize the dataset. The dataset is then split into three parts: part one, part two, and part three, with the same sizes as in the previous lesson. The same random seed is used to ensure reproducibility. Digits 1, 3, and 7 are excluded from part one; digits 2, 5, and 8 are excluded from part two; and digits 4, 6, and 9 are excluded from part three. For later use, you also put all three training datasets in a list called train_sets. The full MNIST test dataset is loaded with the same transformation, and you also create the same three subsets of the test dataset, test set 137, test set 258, and test set 469, as in lesson one.

In federated learning, you need to exchange model parameters between server and clients. When the client receives model parameters from the server, it needs to update the local model with those new parameters received from the server. When the client finishes training, it needs to send the latest version of the local model parameters back to the server. To enable this, we need two functions: set_weights and get_weights.
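The digit-exclusion split can be sketched independently of torchvision. The helper below, exclude_digits, is a hypothetical illustration of the idea (the lab's utils provide something similar); it works on any indexable sequence of (image, label) pairs.

```python
def exclude_digits(dataset, excluded_digits):
    """Return the indices of samples whose label is NOT excluded.

    dataset: any indexable sequence of (image, label) pairs,
    e.g. a torchvision MNIST dataset or a plain list.
    """
    excluded = set(excluded_digits)
    return [i for i, (_, label) in enumerate(dataset) if label not in excluded]

# With torchvision, the surviving indices would feed torch.utils.data.Subset:
#   part1 = Subset(trainset, exclude_digits(trainset, excluded_digits=[1, 3, 7]))
#   part2 = Subset(trainset, exclude_digits(trainset, excluded_digits=[2, 5, 8]))
#   part3 = Subset(trainset, exclude_digits(trainset, excluded_digits=[4, 6, 9]))
```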
get_weights takes one argument, a reference to our simple PyTorch model. It then iterates over the items in the state dict, converts each one into a NumPy ndarray, and returns a list containing all of those ndarrays. You use get_weights after the local model training has finished, to get the model's updated weights and send them back to the server. set_weights goes the other direction. It takes two arguments: a reference to our simple PyTorch model and a list of ndarrays. It then uses this list of ndarrays to update all the items in the model's state dict. You use set_weights before the local model training, to update the model's weights using the new weights received from the server. Note that both functions are model-specific. If you use a different model, you might need to adjust set_weights and get_weights accordingly.

To connect your existing training and evaluation pipeline, you write a Flower client. The Flower client uses existing functions like train_model and evaluate_model to enable the Flower framework to orchestrate federated training over a set of participating clients. The FlowerClient class is defined as a subclass of NumPyClient. You pass three arguments to the constructor to initialize a FlowerClient object: net, our simple neural net model implemented in PyTorch; the train set, the training dataset of one particular client; and the test set, the test dataset of one particular client. FlowerClient typically defines two methods: the fit method trains the neural net using the provided parameters and the local training dataset, and the evaluate method evaluates the performance of the neural net using the provided parameters and the local test dataset.

To enable the Flower framework to create client objects when necessary, we need to implement a function, called the client function, that creates a FlowerClient instance on demand. This is necessary to optimize resource utilization.
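Based on the description above, get_weights and set_weights might look like this. This is a sketch assuming a plain PyTorch nn.Module; the helper names follow the spoken identifiers in the lesson.

```python
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn

def get_weights(net: nn.Module):
    # Convert every entry in the model's state_dict to a NumPy ndarray,
    # so the weights can be shipped to the server as a list of arrays.
    return [val.cpu().numpy() for val in net.state_dict().values()]

def set_weights(net: nn.Module, parameters):
    # Rebuild a state_dict from the list of ndarrays, relying on the
    # stable ordering of the model's state_dict keys.
    params_dict = zip(net.state_dict().keys(), parameters)
    state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
    net.load_state_dict(state_dict, strict=True)
```

Copying the weights of one model into a second, identically shaped model and reading them back should yield identical arrays.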
Federated training can easily span hundreds of clients, but when building such systems, you want to simulate them efficiently on a single machine. Flower calls the client function whenever it needs an instance of one particular client to call fit or evaluate on. This enables the framework to discard those objects and free up resources when a particular client object isn't needed. Last, we create an instance of ClientApp by passing it the previously defined client function. ClientApp is the entry point to everything happening on the client side, and you will learn about additional ClientApp features in an upcoming lesson.

Now that you have a client app that can perform local training and evaluation, you need a server-side counterpart to aggregate the updated models received from the clients. Of course, you also want to see how the global model performs. For that, you start by creating an evaluation function called evaluate. evaluate takes three arguments: the current server round, the latest version of the global model parameters as a list of NumPy ndarrays, and the config dict. Similar to FlowerClient.evaluate, it updates the model with the latest parameters using set_weights. It evaluates the model's performance on multiple test datasets, in our case the full test set as well as test set 137, test set 258, and test set 469, to assess its accuracy on the full MNIST test set but also on the different subsets of digits. It prints evaluation results, including accuracy on all digits and on the specific subsets 137, 258, and 469, for each server round, so we can see how the accuracy evolves over multiple rounds of federated learning. If the current server round is the final round, which is indicated by server_round being equal to three, it computes and plots the confusion matrix for the model's predictions on the entire test dataset. This will allow you to understand how the federated model performs compared to the three individual models in lesson one.
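Putting the client-side pieces together, the FlowerClient, the client function, and the ClientApp might be wired up roughly like this. This is a sketch, not a definitive implementation: SimpleModel, train_model, evaluate_model, train_sets, and testset are assumed to come from the lab's utils, and the exact client_fn signature varies across Flower versions.

```python
from flwr.client import ClientApp, NumPyClient
from flwr.common import Context

class FlowerClient(NumPyClient):
    def __init__(self, net, trainset, testset):
        self.net = net
        self.trainset = trainset
        self.testset = testset

    def fit(self, parameters, config):
        # Update the local model with the global weights, train locally,
        # and return the improved weights plus the local dataset size
        # (used by FedAvg to weight the average).
        set_weights(self.net, parameters)
        train_model(self.net, self.trainset)
        return get_weights(self.net), len(self.trainset), {}

    def evaluate(self, parameters, config):
        # Evaluate the received global model on the local test set.
        set_weights(self.net, parameters)
        loss, accuracy = evaluate_model(self.net, self.testset)
        return loss, len(self.testset), {"accuracy": accuracy}

def client_fn(context: Context):
    # Called by the framework to create one client on demand;
    # the object is discarded again once fit or evaluate finishes.
    partition_id = int(context.node_config["partition-id"])
    net = SimpleModel()  # assumed model class from the lab's utils
    return FlowerClient(net, train_sets[partition_id], testset).to_client()

client = ClientApp(client_fn)
```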
To create a server app, you need to decide which strategy you want to use. A strategy is an abstraction that implements a server-side federated learning algorithm. We covered federated averaging before, but there are also many other algorithms like FedAdam, FedMedian, and QFedAvg, many of them available as built-in strategies in Flower. Let's start with plain federated averaging. To initialize FedAvg, you need to pass fraction fit, the fraction of available clients selected for training; fraction evaluate, the fraction of available clients selected for evaluation; initial parameters, the initial model weights; and the evaluate function, the function to use for server-side model evaluation. Now that you have a strategy, you can create an instance of ServerApp.

With the ClientApp and the ServerApp, you can finally start the training. A real-world federated learning system would be distributed across a number of servers or user devices. In this notebook environment, you simulate such a system by running everything on a single machine. For that, you use a function called run_simulation that takes three arguments: the server app, the client app, and num_supernodes, the number of clients that run_simulation should simulate. Flower calls clients "super nodes" to emphasize the importance of those nodes in the federated learning process. Compared to a traditional client-server setup with a powerful server and thin clients, in federated learning the clients are the real stars of the show. They have the valuable data and the compute to perform training. Let's run it!

You can see that run_simulation starts the Flower ServerApp using a configuration that performs three rounds of federated learning, and it tells us that it's using the initial global model parameters provided by the strategy. It then continues to evaluate those initial global model parameters using the evaluation function that you defined earlier.
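The server side and the simulation call can then be wired up roughly as follows. This is a sketch following the flwr 1.8-era API (newer Flower versions move the strategy construction into a server_fn); SimpleModel, get_weights, the evaluate function, and the client ClientApp are assumed from the lab code described above.

```python
from flwr.common import ndarrays_to_parameters
from flwr.server import ServerApp, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation

net = SimpleModel()  # assumed model class from the lab's utils
strategy = FedAvg(
    fraction_fit=1.0,        # train on all available clients each round
    fraction_evaluate=0.0,   # skip client-side evaluation in this lab
    initial_parameters=ndarrays_to_parameters(get_weights(net)),
    evaluate_fn=evaluate,    # server-side evaluation function defined earlier
)
server = ServerApp(config=ServerConfig(num_rounds=3), strategy=strategy)

# Simulate three clients (super nodes) on a single machine.
run_simulation(server_app=server, client_app=client, num_supernodes=3)
```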
This means it evaluates the test accuracy on all digits, but also on the three different subsets that you defined earlier. It then continues to go into round one of federated learning. The strategy samples three clients out of a total of three available clients. It samples those clients, sends the global model parameters to those three clients, and asks those clients to perform the training on their local datasets. This will take a minute to finish. You can now see that the server received three results and zero failures. All three clients submitted their results back to the server. The server then continues to call the evaluation function again to see how the newly aggregated global model performs on the full test set, but also on the three different test sets that you defined earlier.

The server then proceeds to round two. It again samples three clients out of a total of three available clients. You can see that, again, the server received three results and zero failures. It then again calls the evaluation function to evaluate the newly aggregated global model, and then proceeds to round three. For round three, it again samples three clients out of three available clients. This will again take a minute to complete on the client side. When the three clients finish their local training, they send their updated models back to the server. You can see that the server receives three results and zero failures. It then continues to call evaluate to evaluate the final global model on the full MNIST test set, but also on the three specific subsets you built earlier.

Remember that in lesson one, when you trained three individual models, those models achieved an accuracy of roughly between 65 and 70%. Here, with federated learning, you can see a big jump in accuracy to 96%. This means that the global model is much better than any of the individual models trained in lesson one.
Perhaps even more interesting, you can see that on those specific subsets, the global model sees a jump in accuracy from 0% previously to between 94 and 96%. Looking at the confusion matrix, we see a very different picture compared to the confusion matrices of the individual models trained in lesson one. The model learns how to classify all digits, even if those digits are missing in one of the datasets. So we don't see any columns that only have zeros in them anymore.

Let's review lesson two. Federated learning is an iterative process. Clients with data train the model, and the server, often without data, aggregates the model updates. You define client-side training, evaluation, or analytics via the Flower ClientApp, and you define server-side configuration or aggregation via the Flower ServerApp. During development, you usually simulate those systems on a single machine, but for a production setting, you deploy them across the different machines that hold the individual datasets.