• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2022 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for aggregation_protocols."""
15
16import tempfile
17from typing import Any
18from unittest import mock
19
20from absl.testing import absltest
21import tensorflow as tf
22
23from fcp.aggregation.protocol import aggregation_protocol_messages_pb2 as apm_pb2
24from fcp.aggregation.protocol import configuration_pb2
25from fcp.aggregation.protocol.python import aggregation_protocol
26from fcp.aggregation.tensorflow.python import aggregation_protocols
27from pybind11_abseil import status
28
29
def create_client_input(tensors: dict[str, Any]) -> apm_pb2.ClientMessage:
  """Wraps a set of named tensors into a simple-aggregation client message.

  The tensors are saved to a temporary file with `tf.raw_ops.Save`, and the
  resulting file contents are embedded as inline bytes in the message.

  Args:
    tensors: Mapping from tensor name to the value to save under that name.

  Returns:
    An `apm_pb2.ClientMessage` whose `simple_aggregation.input` resource holds
    the saved file's bytes inline.
  """
  names = list(tensors.keys())
  values = list(tensors.values())
  with tempfile.NamedTemporaryFile() as checkpoint:
    tf.raw_ops.Save(filename=checkpoint.name, tensor_names=names, data=values)
    # Read back through a fresh handle opened on the path Save wrote to.
    with open(checkpoint.name, 'rb') as reader:
      payload = reader.read()
  resource = apm_pb2.ClientResource(inline_bytes=payload)
  aggregation = apm_pb2.ClientMessage.SimpleAggregation(input=resource)
  return apm_pb2.ClientMessage(simple_aggregation=aggregation)
40
41
class CallbackProxy(aggregation_protocol.AggregationProtocol.Callback):
  """Relays every Callback notification to another, wrapped Callback.

  pybind11 does not recognize `mock.Mock` objects as `Callback` subclasses,
  so the protocol is given this genuine subclass instead. It forwards each
  call unchanged to the callback it wraps, which may be a mock.
  """

  def __init__(
      self, callback: aggregation_protocol.AggregationProtocol.Callback
  ) -> None:
    super().__init__()
    self._callback = callback

  # Arguments are forwarded positionally on purpose: tests inspect the mock's
  # `call_args.args`, which only captures positional arguments.

  def OnAcceptClients(
      self,
      start_client_id: int,
      num_clients: int,
      message: apm_pb2.AcceptanceMessage,
  ) -> None:
    """Forwards the client-acceptance notification."""
    self._callback.OnAcceptClients(start_client_id, num_clients, message)

  def OnSendServerMessage(
      self, client_id: int, message: apm_pb2.ServerMessage
  ) -> None:
    """Forwards a server-to-client message."""
    self._callback.OnSendServerMessage(client_id, message)

  def OnCloseClient(
      self, client_id: int, diagnostic_status: status.Status
  ) -> None:
    """Forwards the close notification for a single client."""
    self._callback.OnCloseClient(client_id, diagnostic_status)

  def OnComplete(self, result: bytes) -> None:
    """Forwards the final aggregation result."""
    self._callback.OnComplete(result)

  def OnAbort(self, diagnostic_status: status.Status) -> None:
    """Forwards the abort notification."""
    self._callback.OnAbort(diagnostic_status)
69
70
class AggregationProtocolsTest(absltest.TestCase):
  """Tests for protocols created by `aggregation_protocols`."""

  def test_simple_aggregation_protocol(self):
    """Runs a two-client `federated_sum` aggregation end to end.

    Each client sends one scalar int32 value (3 and 5). The checkpoint
    delivered to `OnComplete` must hold their sum (8) under the configured
    output tensor name. The protocol must not abort at any point.
    """
    input_tensor = tf.TensorSpec((), tf.int32, 'in')
    output_tensor = tf.TensorSpec((), tf.int32, 'out')
    config = configuration_pb2.Configuration(aggregation_configs=[
        configuration_pb2.Configuration.ServerAggregationConfig(
            intrinsic_uri='federated_sum',
            intrinsic_args=[
                configuration_pb2.Configuration.ServerAggregationConfig.
                IntrinsicArg(input_tensor=input_tensor.experimental_as_proto()),
            ],
            output_tensors=[output_tensor.experimental_as_proto()],
        ),
    ])
    # The autospec mock records callback invocations. It is wrapped in
    # CallbackProxy because pybind11 does not accept the mock directly.
    callback = mock.create_autospec(
        aggregation_protocol.AggregationProtocol.Callback, instance=True)

    agg_protocol = aggregation_protocols.create_simple_aggregation_protocol(
        config, CallbackProxy(callback))
    self.assertIsNotNone(agg_protocol)

    agg_protocol.Start(2)
    callback.OnAcceptClients.assert_called_once_with(mock.ANY, 2, mock.ANY)
    # The protocol chooses the client ids; the accepted range starts at the
    # first positional argument passed to OnAcceptClients.
    start_client_id = callback.OnAcceptClients.call_args.args[0]

    agg_protocol.ReceiveClientMessage(
        start_client_id, create_client_input({input_tensor.name: 3}))
    agg_protocol.ReceiveClientMessage(
        start_client_id + 1, create_client_input({input_tensor.name: 5}))
    callback.OnCloseClient.assert_has_calls([
        mock.call(start_client_id, status.Status.OkStatus()),
        mock.call(start_client_id + 1, status.Status.OkStatus()),
    ])

    agg_protocol.Complete()
    callback.OnComplete.assert_called_once()
    # A successful run must never report an abort. Without this check, a
    # protocol that called both OnComplete and OnAbort would still pass.
    callback.OnAbort.assert_not_called()
    with tempfile.NamedTemporaryFile('wb') as tmpfile:
      tmpfile.write(callback.OnComplete.call_args.args[0])
      # Flush so Restore, which reads by path, sees the full contents.
      tmpfile.flush()
      self.assertEqual(
          tf.raw_ops.Restore(
              file_pattern=tmpfile.name,
              tensor_name=output_tensor.name,
              dt=output_tensor.dtype), 8)
116
117
if __name__ == '__main__':
  # Run this module's tests through absl's test entry point.
  absltest.main()
120