Basic federated model implementation

parent 891abc69
import collections
import attr
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
import time
np.random.seed(0)
client_data, _ = tff.simulation.datasets.emnist.load_data()
first_client_id = client_data.client_ids[0]
first_client_dataset = client_data.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
# This information is also available as a `ClientData` property:
assert client_data.element_type_structure == first_client_dataset.element_spec
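# For raw EMNIST the element spec is expected to look roughly like this
# (exact reprs may vary across TFF versions):
# OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)),
#              ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])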
def preprocess_dataset(dataset):
  """Creates batches of 5 examples and limits the dataset to 5 batches."""

  def map_fn(element):
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], shape=(-1, 784)),
        y=tf.cast(tf.reshape(element['label'], shape=(-1, 1)), tf.int64),
    )

  return dataset.batch(5).map(
      map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)
preprocessed_client_data = client_data.preprocess(preprocess_dataset)
# Notice that we have both reshaped and renamed the elements of the ordered dict.
# first_client_dataset = preprocessed_client_data.create_tf_dataset_for_client(
# first_client_id)
# print(first_client_dataset.element_spec)
def preprocess_and_shuffle(dataset):
"""Applies `preprocess_dataset` above and shuffles the result."""
preprocessed = preprocess_dataset(dataset)
return preprocessed.shuffle(buffer_size=5)
preprocessed_and_shuffled = client_data.preprocess(preprocess_and_shuffle)
# The type signature will remain the same, but the batches will be shuffled.
first_client_dataset = preprocessed_and_shuffled.create_tf_dataset_for_client(
first_client_id)
print(first_client_dataset.element_spec)
###################################################
def model_fn():
model = tf.keras.models.Sequential([
tf.keras.layers.InputLayer(input_shape=(784,)),
tf.keras.layers.Dense(10, kernel_initializer='zeros'),
])
return tff.learning.from_keras_model(
model,
# Note: input spec is the _batched_ shape, and includes the
# label tensor which will be passed to the loss function. This model is
# therefore configured to accept data _after_ it has been preprocessed.
      input_spec=first_client_dataset.element_spec,
      # Equivalently, the spec could be written out explicitly:
      # collections.OrderedDict(
      #     x=tf.TensorSpec(shape=[None, 784], dtype=tf.float32),
      #     y=tf.TensorSpec(shape=[None, 1], dtype=tf.int64)),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
trainer = tff.learning.build_federated_averaging_process(
model_fn,
client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.01))
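# Optional inspection (a sketch): the iterative process exposes its
# `initialize` and `next` computations, whose TFF type signatures can be
# printed directly.
# print(trainer.initialize.type_signature)
# print(trainer.next.type_signature)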
selected_client_ids = preprocessed_and_shuffled.client_ids[:10]
print(str(selected_client_ids))
preprocessed_data_for_clients = [
    preprocessed_and_shuffled.create_tf_dataset_for_client(client_id)
    for client_id in selected_client_ids
]
state = trainer.initialize()
for _ in range(5):
t1 = time.time()
state, metrics = trainer.next(state, preprocessed_data_for_clients)
t2 = time.time()
print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
print('dataset computation without preprocessing:')
print(client_data.dataset_computation.type_signature)
print('\n')
print('dataset computation with preprocessing:')
print(preprocessed_and_shuffled.dataset_computation.type_signature)
trainer_accepting_ids = tff.simulation.compose_dataset_computation_with_iterative_process(
preprocessed_and_shuffled.dataset_computation, trainer)
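# The composed process now accepts raw client ids in place of datasets, so
# preprocessing happens inside the TFF runtime. A quick check (a sketch):
# print(trainer_accepting_ids.next.type_signature)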
for _ in range(5):
t1 = time.time()
state, metrics = trainer_accepting_ids.next(state, selected_client_ids)
t2 = time.time()
print('loss {}, round time {}'.format(metrics['train']['loss'], t2 - t1))
#!/bin/bash
#source venv/bin/activate
python processing/main_processing.py >> implementationLog
Submitted batch job 1721
Submitted batch job 1722
Submitted batch job 1723
Submitted batch job 1724
Submitted batch job 1725
Submitted batch job 1726
Submitted batch job 1727
Submitted batch job 1728
Submitted batch job 1729
Submitted batch job 1730
Submitted batch job 1731
Submitted batch job 1732
Submitted batch job 1733
Submitted batch job 1734
Submitted batch job 1735
Submitted batch job 1736
Submitted batch job 1737
Submitted batch job 1738
Submitted batch job 1739
Submitted batch job 1740
Submitted batch job 1741
Submitted batch job 1742
Submitted batch job 1743
Submitted batch job 1744
Submitted batch job 1745
Submitted batch job 1746
Submitted batch job 1747
Submitted batch job 1748
Submitted batch job 1749
Submitted batch job 1750
Submitted batch job 1751
Submitted batch job 1752
Submitted batch job 1753
Submitted batch job 1754
Submitted batch job 1755
Submitted batch job 1756
Submitted batch job 1757
Submitted batch job 1758
Submitted batch job 1759
Submitted batch job 1760
Submitted batch job 1761
Submitted batch job 1762
Submitted batch job 1763
Submitted batch job 1764
Submitted batch job 1765
Submitted batch job 1766
Submitted batch job 1767
Submitted batch job 1768
Submitted batch job 1769
Submitted batch job 1777
Submitted batch job 1778
Submitted batch job 1779
Submitted batch job 1770
Submitted batch job 1771
Submitted batch job 1772
Submitted batch job 1773
Submitted batch job 1774
Submitted batch job 1775
Submitted batch job 1776
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from models import MnistModel
def evaluate(server_state, central_emnist_test):
  """Evaluates `server_state` (a list of weight tensors) on a central test set.

  Note: `MnistModel` is a `tff.learning.Model`, not a Keras model, so an
  equivalent Keras model (matching its weight shapes) is built here.
  """
  keras_model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_emnist_test)
def evaluate2(server_state, tff_learning_model):
  # TODO: assign the `server_state` weights to the model before evaluating.
  # (First idea: reuse the `server_update` logic?)
  evaluation = tff.learning.build_federated_evaluation(tff_learning_model)
# keras_model = MnistModel()
# keras_model.compile(
# loss=tf.keras.losses.SparseCategoricalCrossentropy(),
# metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
# )
# keras_model.set_weights(server_state)
# keras_model.evaluate(central_emnist_test)
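# One possible completion of the TODO above (a sketch, not verified here;
# the function name is illustrative): wrap the raw list of trainable tensors
# in `tff.learning.ModelWeights` before calling the federated evaluation.
def evaluate2_completed(server_state, tff_learning_model, federated_test_data):
  evaluation = tff.learning.build_federated_evaluation(tff_learning_model)
  model_weights = tff.learning.ModelWeights(
      trainable=server_state, non_trainable=[])
  return evaluation(model_weights, federated_test_data)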
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from models import MnistModel
#tf_dataset_type = None
#model_weights_type = None
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
"""Performs training (using the server model weights) on the client's dataset."""
# Initialize the client model with the current server weights.
client_weights = model.trainable_variables
# Assign the server weights to the client model.
tf.nest.map_structure(lambda x, y: x.assign(y),
client_weights, server_weights)
# Use the client_optimizer to update the local model.
for batch in dataset:
with tf.GradientTape() as tape:
# Compute a forward pass on the batch of data
outputs = model.forward_pass(batch)
# Compute the corresponding gradient
grads = tape.gradient(outputs.loss, client_weights)
grads_and_vars = zip(grads, client_weights)
# Apply the gradient using a client optimizer.
client_optimizer.apply_gradients(grads_and_vars)
return client_weights
@tf.function
def server_update(model, mean_client_weights):
"""Updates the server model weights as the average of the client model weights."""
model_weights = model.trainable_variables
# Assign the mean client weights to the server model.
tf.nest.map_structure(lambda x, y: x.assign(y),
model_weights, mean_client_weights)
return model_weights
#Creating the initialization computation
@tff.tf_computation
def server_init():
model = MnistModel() #model_fn()
return model.trainable_variables
@tff.federated_computation
def initialize_fn():
return tff.federated_value(server_init(), tff.SERVER)
whimsy_model = MnistModel()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
model_weights_type = server_init.type_signature.result
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
model = MnistModel()#model_fn()
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
return client_update(model, tf_dataset, server_weights, client_optimizer)
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
model = MnistModel()#model_fn()
return server_update(model, mean_client_weights)
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
# Broadcast the server weights to the clients.
print("server_weights")
print(str(server_weights.type_signature))
server_weights_at_client = tff.federated_broadcast(server_weights)
# Each client computes their updated weights.
client_weights = tff.federated_map(
client_update_fn, (federated_dataset, server_weights_at_client))
# The server averages these updates.
mean_client_weights = tff.federated_mean(client_weights)
print("mean_client_wieghts")
print(str(mean_client_weights.type_signature))
# The server updates its model.
server_weights = tff.federated_map(server_update_fn, mean_client_weights)
return server_weights
def get_federated_algorithm():
  # Creating the next_fn.
  # Getting the data types, needed explicitly in the model functions.
whimsy_model = MnistModel() #model_fn()
global tf_dataset_type
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)
print("tf_dataset_type")
print(str(tf_dataset_type))
global model_weights_type
model_weights_type = server_init.type_signature.result
print("model_weights_type")
print(str(model_weights_type))
# finished printing types
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
federated_algorithm = tff.templates.IterativeProcess(
initialize_fn=initialize_fn,
next_fn=next_fn
)
return federated_algorithm
def merge_2_states(server_state1, server_state2):
  """Averages two server states (lists of weight tensors) elementwise."""
  return np.mean(np.array([server_state1, server_state2]), axis=0)
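# A per-layer alternative (a sketch): averaging each pair of weight tensors
# explicitly avoids relying on NumPy object-array semantics for nested
# structures of differently shaped arrays.
def merge_2_states_per_layer(server_state1, server_state2):
  return [np.mean(np.array([w1, w2]), axis=0)
          for w1, w2 in zip(server_state1, server_state2)]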
import collections
import attr
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from preprocessing import get_client_ids
from preprocessing import get_federated_train_data
from preprocessing import preprocess
from models import MnistModel
from federated_training_algorithm import get_federated_algorithm
from federated_training_algorithm import merge_2_states
from evaluation import evaluate
np.random.seed(0)
print("## Starting...")
#GET THE DATA (for now it's the default dataset)
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
# client_ids = get_client_ids(emnist_train)  # not used
print("## Preprocessing the federated_train_data")
federated_train_data = get_federated_train_data(emnist_train)
print("## Declaring the model")
#it is done in models.py
print("## Declaring the federated algorithm")
federated_algorithm = get_federated_algorithm()
server_state = federated_algorithm.initialize()
for round_num in range(20):
  server_state = federated_algorithm.next(server_state, federated_train_data)
print("server_state type")
print(str(type(server_state)))
print(str(type(server_state[0])))
print("Finished training the first server state")
server2_state = federated_algorithm.initialize()
for round_num in range(2):
  server2_state = federated_algorithm.next(server2_state, federated_train_data)
# Merge the two trained states before inspecting them.
merged_state = merge_2_states(server_state, server2_state)
print("server_state[1]")
print(server_state[1])
print("server2_state[1]")
print(server2_state[1])
print("merged_state[1]")
print(merged_state[1])
# print("federated_algorithm.initialize.type_signature")
# print(str(federated_algorithm.initialize.type_signature))
# print("federated_algorithm.next.type_signature")
# print(str(federated_algorithm.next.type_signature))
# print("## Training the model")
# iterative_process = tff.learning.build_federated_averaging_process(
# MnistModel,
# client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
# state = iterative_process.initialize()
# for round_num in range(1, 11):
# state, metrics = iterative_process.next(state, federated_train_data)
# print('round {:2d}, metrics={}'.format(round_num, metrics))
#evaluation = tff.learning.build_federated_evaluation(MnistModel)
#TODO integration
print("## Evaluation of the model")
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
import collections
input_spec_data_global = None
MnistVariables = collections.namedtuple(
'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')
def create_mnist_variables():
return MnistVariables(
weights=tf.Variable(
lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
name='weights',
trainable=True),
bias=tf.Variable(
lambda: tf.zeros(dtype=tf.float32, shape=(10)),
name='bias',
trainable=True),
num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))
def mnist_forward_pass(variables, batch):
y = tf.nn.softmax(tf.matmul(batch['x'], variables.weights) + variables.bias)
predictions = tf.cast(tf.argmax(y, 1), tf.int32)
flat_labels = tf.reshape(batch['y'], [-1])
loss = -tf.reduce_mean(
tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
accuracy = tf.reduce_mean(
tf.cast(tf.equal(predictions, flat_labels), tf.float32))
num_examples = tf.cast(tf.size(batch['y']), tf.float32)
variables.num_examples.assign_add(num_examples)
variables.loss_sum.assign_add(loss * num_examples)
variables.accuracy_sum.assign_add(accuracy * num_examples)
return loss, predictions
def get_local_mnist_metrics(variables):
return collections.OrderedDict(
num_examples=variables.num_examples,
loss=variables.loss_sum / variables.num_examples,
accuracy=variables.accuracy_sum / variables.num_examples)
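# `tff.federated_mean` below uses `num_examples` as the weight, so clients
# holding more data contribute proportionally more to the global loss and
# accuracy, while `num_examples` itself is aggregated with an unweighted sum.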
@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
return collections.OrderedDict(
num_examples=tff.federated_sum(metrics.num_examples),
loss=tff.federated_mean(metrics.loss, metrics.num_examples),
accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))
class MnistModel(tff.learning.Model):
def __init__(self):
self._variables = create_mnist_variables()
@property
def trainable_variables(self):
return [self._variables.weights, self._variables.bias]
@property
def non_trainable_variables(self):
return []
@property
def local_variables(self):
return [
self._variables.num_examples, self._variables.loss_sum,
self._variables.accuracy_sum
]
@property
def input_spec(self):
return collections.OrderedDict(
x=tf.TensorSpec([None, 784], tf.float32),
y=tf.TensorSpec([None, 1], tf.int32))
@tf.function
def forward_pass(self, batch, training=True):
del training
loss, predictions = mnist_forward_pass(self._variables, batch)
    num_examples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_examples)
@tf.function
def report_local_outputs(self):
return get_local_mnist_metrics(self._variables)
@property
def federated_output_computation(self):
return aggregate_mnist_metrics_across_clients
import collections
import attr
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
np.random.seed(0)
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
# NUM_CLIENTS = 10
# BATCH_SIZE = 20
def preprocess(dataset, batch_size=20):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]),
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(batch_size).map(batch_format_fn)

def get_client_ids(emnist_train, num_clients=10):
  return np.random.choice(emnist_train.client_ids, size=num_clients, replace=False)
# client_ids = np.random.choice(emnist_train.client_ids, size=NUM_CLIENTS, replace=False)

def get_federated_train_data(emnist_train, num_clients=10):
  client_ids = get_client_ids(emnist_train, num_clients)
  return [preprocess(emnist_train.create_tf_dataset_for_client(x))
          for x in client_ids]
# federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
# for x in client_ids
# ]
##################################################################
## Second dataset
##################################################################
def preprocess_and_shuffle(dataset):
  """Applies `preprocess` above and shuffles the result."""
  preprocessed = preprocess(dataset)
  return preprocessed.shuffle(buffer_size=5)
#!/bin/bash
#source venv/bin/activate
python prototypeImplementation.py >> prototypeLog
#global imports
from db.repository import Repository
from db.table_repository import TableRepository
from db.use_case_repository import UseCaseRepository
from db.entities.layer_adapter import LayerAdapter
from db.entities.table import Table
from typing import List
class LayerAdapterService:
_table_repository = TableRepository()
_layer_repository = Repository()
_use_case_repository = UseCaseRepository()
@staticmethod
def check_layer(layer: LayerAdapter):
'''
    Checks whether the given layer's mappings are consistent with the schema of its use_case.
'''
# TODO implement with tables
# schema = LayerAdapterService._schema_repository.put(layer.use_case)
# for p in layer.properties:
# if p not in schema.mappings:
# raise ValueError(f'{p} is not existent in the schema!')
@staticmethod
def add_complete(layer: LayerAdapter):
'''
    Adds a new layer to the DB, using the attribute mappings and cluster
    attributes of the given layer. Before inserting, the layer is checked for consistency with the schema.
@params:
layer - Required : layer object holding correct data
'''
LayerAdapterService.check_layer(layer)
LayerAdapterService._layer_repository.add(layer)
@staticmethod
def delete_all_use_cases():
# TODO
LayerAdapterService._layer_repository.delete_all_use_cases()
LayerAdapterService._table_repository.delete_all()
LayerAdapterService._use_case_repository.delete_all()
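# Hedged usage sketch (constructor arguments are illustrative, since the
# LayerAdapter schema is defined elsewhere):
# adapter = LayerAdapter(...)  # build from use-case data
# LayerAdapterService.add_complete(adapter)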
2021-05-04 14:08:53.127296: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2021-05-04 14:09:04.981860: W tensorflow/stream_executor/platform/default/dso_loader.cc:59] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64
2021-05-04 14:09:04.982134: W tensorflow/stream_executor/cuda/cuda_driver.cc:312] failed call to cuInit: UNKNOWN ERROR (303)
2021-05-04 14:09:04.982158: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (mcore2): /proc/driver/nvidia/version does not exist
2021-05-04 14:09:04.982428: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-05-04 14:09:05.002724: I tensorflow/core/platform/profile_utils/cpu_utils.cc:104] CPU Frequency: 2599705000 Hz
2021-05-04 14:09:05.004362: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x5583288d82c0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2021-05-04 14:09:05.004381: I tensorflow/compiler/xla/service/service.cc:176] StreamExecutor device (0): Host, Default Version
Traceback (most recent call last):
File "processing/main_processing.py", line 36, in <module>
print(str(server_state.type_signature))
AttributeError: 'list' object has no attribute 'type_signature'
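# Note on the traceback above: the custom algorithm returns the server state
# as a plain Python list of tensors, so it has no `type_signature` attribute.
# Type signatures live on the TFF computations instead, e.g.:
# print(federated_algorithm.initialize.type_signature)
# print(federated_algorithm.next.type_signature)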
import collections
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
# Check whether this import is still needed:
from matplotlib import pyplot as plt
np.random.seed(0)
print(tff.federated_computation(lambda: 'Hello, World!')())
###########################################################
####### LOADING DATASET ###################################
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
example_dataset = emnist_train.create_tf_dataset_for_client(
emnist_train.client_ids[0])
###########################################################
####### Preprocessing data ################################
NUM_CLIENTS = 5
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10
def preprocess(dataset):
def batch_format_fn(element):
"""Flatten a batch `pixels` and return the features as an `OrderedDict`."""
return collections.OrderedDict(
x=tf.reshape(element['pixels'], [-1, 784]),
y=tf.reshape(element['label'], [-1, 1]))
return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)
preprocessed_example_dataset = preprocess(example_dataset)
sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
next(iter(preprocessed_example_dataset)))
print("sample_batch")
print(sample_batch)
def make_federated_data(client_data, client_ids):
return [
preprocess(client_data.create_tf_dataset_for_client(x))
for x in client_ids
]
sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]
federated_train_data = make_federated_data(emnist_train, sample_clients)
print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
#######################################################################################
###### Creating the tutorial model
#######################################################################################
# def create_keras_model():
# return tf.keras.models.Sequential([
# tf.keras.layers.InputLayer(input_shape=(784,)),
# tf.keras.layers.Dense(10, kernel_initializer='zeros'),
# tf.keras.layers.Softmax(),
# ])
# def model_fn():
# # We _must_ create a new model here, and _not_ capture it from an external
# # scope. TFF will call this within different graph contexts.
# keras_model = create_keras_model()
# return tff.learning.from_keras_model(
# keras_model,
# input_spec=preprocessed_example_dataset.element_spec,
# loss=tf.keras.losses.SparseCategoricalCrossentropy(),
# metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
# #######################################################################################
# ####Tutorial implementation
# iterative_process = tff.learning.build_federated_averaging_process(
# model_fn,
# client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
# server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
# print("\nInitialize Iterative Procees signature:")
# print(str(iterative_process.initialize.type_signature)) ## Just a print for info
# state = iterative_process.initialize()
# NUM_ROUNDS = 10
# print("\nStarting {} rounds of training:".format(NUM_ROUNDS))
# for round_num in range(1, NUM_ROUNDS+1):
# state, metrics = iterative_process.next(state, federated_train_data)
# print('round {:2d}, metrics={}'.format(round_num, metrics))
#######################################################################################
####### Customizing the model implementation ##########################################
#######################################################################################
print(" STAAARTIIING")
MnistVariables = collections.namedtuple(
'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')
def create_mnist_variables():
return MnistVariables(
weights=tf.Variable(
lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
name='weights',
trainable=True),
bias=tf.Variable(
lambda: tf.zeros(dtype=tf.float32, shape=(10)),
name='bias',
trainable=True),
num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))
def mnist_forward_pass(variables, batch):
y = tf.nn.softmax(tf.matmul(batch['x'], variables.weights) + variables.bias)
predictions = tf.cast(tf.argmax(y, 1), tf.int32)
flat_labels = tf.reshape(batch['y'], [-1])
loss = -tf.reduce_mean(
tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
accuracy = tf.reduce_mean(
tf.cast(tf.equal(predictions, flat_labels), tf.float32))
num_examples = tf.cast(tf.size(batch['y']), tf.float32)
variables.num_examples.assign_add(num_examples)
variables.loss_sum.assign_add(loss * num_examples)
variables.accuracy_sum.assign_add(accuracy * num_examples)
return loss, predictions
def get_local_mnist_metrics(variables):
return collections.OrderedDict(
num_examples=variables.num_examples,
loss=variables.loss_sum / variables.num_examples,
accuracy=variables.accuracy_sum / variables.num_examples)
@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
return collections.OrderedDict(
num_examples=tff.federated_sum(metrics.num_examples),
loss=tff.federated_mean(metrics.loss, metrics.num_examples),
accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))
class MnistModel(tff.learning.Model):
def __init__(self):
self._variables = create_mnist_variables()
@property
def trainable_variables(self):
return [self._variables.weights, self._variables.bias]
@property
def non_trainable_variables(self):
return []
@property
def local_variables(self):
return [
self._variables.num_examples, self._variables.loss_sum,
self._variables.accuracy_sum
]
@property
def input_spec(self):
return collections.OrderedDict(
x=tf.TensorSpec([None, 784], tf.float32),
y=tf.TensorSpec([None, 1], tf.int32))
@tf.function
def forward_pass(self, batch, training=True):
del training
loss, predictions = mnist_forward_pass(self._variables, batch)
    num_examples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_examples)
@tf.function
def report_local_outputs(self):
return get_local_mnist_metrics(self._variables)
@property
def federated_output_computation(self):
return aggregate_mnist_metrics_across_clients
iterative_process = tff.learning.build_federated_averaging_process(
MnistModel,
client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
NUM_ROUNDS = 10
print("\nStarting {} rounds of training:".format(NUM_ROUNDS))
for round_num in range(1, NUM_ROUNDS + 1):
state, metrics = iterative_process.next(state, federated_train_data)
print('round {:2d}, metrics={}'.format(round_num, metrics))
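# Federated evaluation of the trained state, following the TFF tutorial
# (a sketch; reuses the training data as a stand-in for held-out data):
# evaluation = tff.learning.build_federated_evaluation(MnistModel)
# train_metrics = evaluation(state.model, federated_train_data)
# print(str(train_metrics))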