# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for aggregation_protocols."""

import tempfile
from typing import Any
from unittest import mock

from absl.testing import absltest
import tensorflow as tf

from fcp.aggregation.protocol import aggregation_protocol_messages_pb2 as apm_pb2
from fcp.aggregation.protocol import configuration_pb2
from fcp.aggregation.protocol.python import aggregation_protocol
from fcp.aggregation.tensorflow.python import aggregation_protocols
from pybind11_abseil import status


def create_client_input(tensors: dict[str, Any]) -> apm_pb2.ClientMessage:
  """Packs `tensors` into a checkpoint carried by a SimpleAggregation message.

  Args:
    tensors: Mapping from tensor name to value. Every entry is written to a
      temporary checkpoint file with `tf.raw_ops.Save`.

  Returns:
    A `ClientMessage` whose simple-aggregation input holds the checkpoint
    contents as inline bytes.
  """
  with tempfile.NamedTemporaryFile() as checkpoint:
    tf.raw_ops.Save(
        filename=checkpoint.name,
        tensor_names=list(tensors),
        data=list(tensors.values()))
    # `Save` writes through its own handle (it is only given the path), so the
    # checkpoint is read back by path rather than through `checkpoint`.
    with open(checkpoint.name, 'rb') as reader:
      payload = reader.read()
  resource = apm_pb2.ClientResource(inline_bytes=payload)
  return apm_pb2.ClientMessage(
      simple_aggregation=apm_pb2.ClientMessage.SimpleAggregation(
          input=resource))


class CallbackProxy(aggregation_protocol.AggregationProtocol.Callback):
  """Callback subclass that forwards every notification to another Callback.

  pybind11 does not recognize `mock.Mock` objects as `Callback` subclasses, so
  tests give the protocol this real subclass and inspect the wrapped mock.
  """

  def __init__(self,
               callback: aggregation_protocol.AggregationProtocol.Callback):
    super().__init__()
    # Receiver of every forwarded notification (typically a mock in tests).
    self._delegate = callback

  def OnAcceptClients(self, start_client_id: int, num_clients: int,
                      message: apm_pb2.AcceptanceMessage):
    """Forwards the client-acceptance notification."""
    self._delegate.OnAcceptClients(start_client_id, num_clients, message)

  def OnSendServerMessage(self, client_id: int, message: apm_pb2.ServerMessage):
    """Forwards a server-to-client message."""
    self._delegate.OnSendServerMessage(client_id, message)

  def OnCloseClient(self, client_id: int, diagnostic_status: status.Status):
    """Forwards the client-closed notification."""
    self._delegate.OnCloseClient(client_id, diagnostic_status)

  def OnComplete(self, result: bytes):
    """Forwards the serialized aggregation result."""
    self._delegate.OnComplete(result)

  def OnAbort(self, diagnostic_status: status.Status):
    """Forwards the abort notification."""
    self._delegate.OnAbort(diagnostic_status)


class AggregationProtocolsTest(absltest.TestCase):

  def test_simple_aggregation_protocol(self):
    """Runs `federated_sum` over two clients (3 and 5) and expects 8."""
    in_spec = tf.TensorSpec((), tf.int32, 'in')
    out_spec = tf.TensorSpec((), tf.int32, 'out')
    server_config_cls = configuration_pb2.Configuration.ServerAggregationConfig
    sum_arg = server_config_cls.IntrinsicArg(
        input_tensor=in_spec.experimental_as_proto())
    sum_config = server_config_cls(
        intrinsic_uri='federated_sum',
        intrinsic_args=[sum_arg],
        output_tensors=[out_spec.experimental_as_proto()],
    )
    config = configuration_pb2.Configuration(aggregation_configs=[sum_config])
    mock_callback = mock.create_autospec(
        aggregation_protocol.AggregationProtocol.Callback, instance=True)

    protocol = aggregation_protocols.create_simple_aggregation_protocol(
        config, CallbackProxy(mock_callback))
    self.assertIsNotNone(protocol)

    # Ask for two clients and capture the id assigned to the first one.
    protocol.Start(2)
    mock_callback.OnAcceptClients.assert_called_once_with(
        mock.ANY, 2, mock.ANY)
    first_id = mock_callback.OnAcceptClients.call_args.args[0]

    # Client `first_id + i` contributes the i-th value.
    for i, value in enumerate((3, 5)):
      protocol.ReceiveClientMessage(
          first_id + i, create_client_input({in_spec.name: value}))
    mock_callback.OnCloseClient.assert_has_calls(
        [mock.call(first_id + j, status.Status.OkStatus()) for j in range(2)])

    protocol.Complete()
    mock_callback.OnComplete.assert_called_once()
    result_bytes = mock_callback.OnComplete.call_args.args[0]
    # The result is checkpoint bytes; write them out so `Restore` can read the
    # aggregated output tensor back.
    with tempfile.NamedTemporaryFile('wb') as checkpoint:
      checkpoint.write(result_bytes)
      checkpoint.flush()
      restored = tf.raw_ops.Restore(
          file_pattern=checkpoint.name,
          tensor_name=out_spec.name,
          dt=out_spec.dtype)
      self.assertEqual(restored, 8)


if __name__ == '__main__':
  absltest.main()