• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for GCEClusterResolver."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver
22from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GCEClusterResolver
23from tensorflow.python.platform import test
24from tensorflow.python.training import server_lib
25
26
# Alias for the mock library exposed by TensorFlow's test platform module;
# the helpers below use it to build fake GCE Compute API clients.
mock = test.mock
28
29
class GCEClusterResolverTest(test.TestCase):
  """Tests GCEClusterResolver against mocked GCE Compute API clients."""

  # Instance fixture shared by several tests. The helpers below only read it.
  _THREE_INSTANCES = (
      {'name': 'instance1', 'ip': '10.1.2.3'},
      {'name': 'instance2', 'ip': '10.2.3.4'},
      {'name': 'instance3', 'ip': '10.3.4.5'},
  )

  def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
    """Asserts `cluster_spec` equals `expected_proto`, including round-trips.

    The spec is checked directly and after being rebuilt through
    `server_lib.ClusterSpec` from the spec itself, its ClusterDef proto, and
    its dict form.

    Args:
      cluster_spec: The ClusterSpec under test.
      expected_proto: Text-format ClusterDef the spec must match.
    """
    self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
    self.assertProtoEquals(
        expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())

  def standard_mock_instance_groups(self, instance_map=None):
    """Returns a mock of the `instanceGroups()` service resource.

    Args:
      instance_map: List of `{'instance': <url>}` dicts returned as the
        `items` of `listInstances(...).execute()`. Defaults to a single
        instance whose URL ends in `gce-instance-1`.

    Returns:
      A `mock.Mock` whose `listInstances` returns a request yielding
      `instance_map`. Its `listInstances_next` returns None, presumably to
      signal that there are no further result pages.
    """
    if instance_map is None:
      instance_map = [
          {'instance': 'https://gce.example.com/res/gce-instance-1'}
      ]

    mock_instance_group_request = mock.MagicMock()
    mock_instance_group_request.execute.return_value = {
        'items': instance_map
    }

    service_attrs = {
        'listInstances.return_value': mock_instance_group_request,
        'listInstances_next.return_value': None,
    }
    mock_instance_groups = mock.Mock(**service_attrs)
    return mock_instance_groups

  def standard_mock_instances(self, instance_to_ip_map=None):
    """Returns a mock of the `instances()` service resource.

    Args:
      instance_to_ip_map: Dict mapping instance name to the IP reported as
        its first network interface's `networkIP`. Defaults to
        `{'gce-instance-1': '10.123.45.67'}`.

    Returns:
      A `mock.MagicMock` whose `get(project, zone, instance)` returns a
      request whose `execute()` reports that instance's IP. Unknown instance
      names raise RuntimeError.
    """
    if instance_to_ip_map is None:
      instance_to_ip_map = {
          'gce-instance-1': '10.123.45.67'
      }

    def get_side_effect(project, zone, instance):
      del project, zone  # Unused

      if instance not in instance_to_ip_map:
        raise RuntimeError('Instance %s not found!' % instance)

      # Build a fresh request per call so each instance reports its own IP.
      mock_get_request = mock.MagicMock()
      mock_get_request.execute.return_value = {
          'networkInterfaces': [
              {'networkIP': instance_to_ip_map[instance]}
          ]
      }
      return mock_get_request

    service_attrs = {
        'get.side_effect': get_side_effect,
    }
    mock_instances = mock.MagicMock(**service_attrs)
    return mock_instances

  def standard_mock_service_client(
      self,
      mock_instance_groups=None,
      mock_instances=None):
    """Returns a mock Compute API client wired to the given resources.

    Args:
      mock_instance_groups: Mock returned by `client.instanceGroups()`.
        Defaults to `standard_mock_instance_groups()`.
      mock_instances: Mock returned by `client.instances()`. Defaults to
        `standard_mock_instances()`.

    Returns:
      A `mock.MagicMock` service client.
    """
    if mock_instance_groups is None:
      mock_instance_groups = self.standard_mock_instance_groups()
    if mock_instances is None:
      mock_instances = self.standard_mock_instances()

    mock_client = mock.MagicMock()
    mock_client.instanceGroups.return_value = mock_instance_groups
    mock_client.instances.return_value = mock_instances
    return mock_client

  def gen_standard_mock_service_client(self, instances=None):
    """Builds a mock service client from `{'name': ..., 'ip': ...}` dicts.

    Args:
      instances: Iterable of dicts with `name` and `ip` keys. Instances are
        listed in the instance group in iteration order. If None, the default
        single-instance client from `standard_mock_service_client()` is
        returned.

    Returns:
      A mock service client as produced by `standard_mock_service_client`.
    """
    if instances is None:
      # Honor the None default instead of failing when iterating it.
      return self.standard_mock_service_client()

    name_to_ip = {}
    instance_list = []
    for instance in instances:
      name_to_ip[instance['name']] = instance['ip']
      instance_list.append({
          'instance': 'https://gce.example.com/gce/res/' + instance['name']
      })

    mock_instance = self.standard_mock_instances(name_to_ip)
    mock_instance_group = self.standard_mock_instance_groups(instance_list)

    return self.standard_mock_service_client(mock_instance_group, mock_instance)

  def testSimpleSuccessfulRetrieval(self):
    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        port=8470,
        credentials=None,
        service=self.standard_mock_service_client())

    actual_cluster_spec = gce_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'worker' tasks { key: 0 value: '10.123.45.67:8470' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  def testMasterRetrieval(self):
    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_id=0,
        port=8470,
        credentials=None,
        service=self.standard_mock_service_client())
    self.assertEqual(gce_cluster_resolver.master(), 'grpc://10.123.45.67:8470')

  def testMasterRetrievalWithCustomTasks(self):
    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(self._THREE_INSTANCES))

    self.assertEqual(
        gce_cluster_resolver.master('worker', 2, 'test'),
        'test://10.3.4.5:8470')

  def testOverrideParameters(self):
    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='testworker',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(self._THREE_INSTANCES))

    gce_cluster_resolver.task_id = 1
    gce_cluster_resolver.rpc_layer = 'test'

    self.assertEqual(gce_cluster_resolver.task_type, 'testworker')
    self.assertEqual(gce_cluster_resolver.task_id, 1)
    self.assertEqual(gce_cluster_resolver.rpc_layer, 'test')
    self.assertEqual(gce_cluster_resolver.master(), 'test://10.2.3.4:8470')

  def testOverrideParametersWithZeroOrEmpty(self):
    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='',
        task_id=1,
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(self._THREE_INSTANCES))

    self.assertEqual(gce_cluster_resolver.master(
        task_type='', task_id=0), 'grpc://10.1.2.3:8470')

  def testCustomJobNameAndPortRetrieval(self):
    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='custom',
        port=2222,
        credentials=None,
        service=self.standard_mock_service_client())

    actual_cluster_spec = gce_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'custom' tasks { key: 0 value: '10.123.45.67:2222' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  def testMultipleInstancesRetrieval(self):
    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(self._THREE_INSTANCES))

    actual_cluster_spec = gce_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' }
                         tasks { key: 1 value: '10.2.3.4:8470' }
                         tasks { key: 2 value: '10.3.4.5:8470' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  def testUnionMultipleInstanceRetrieval(self):
    worker1_name_to_ip = self._THREE_INSTANCES

    worker2_name_to_ip = [
        {'name': 'instance4', 'ip': '10.4.5.6'},
        {'name': 'instance5', 'ip': '10.5.6.7'},
        {'name': 'instance6', 'ip': '10.6.7.8'},
    ]

    ps_name_to_ip = [
        {'name': 'ps1', 'ip': '10.100.1.2'},
        {'name': 'ps2', 'ip': '10.100.2.3'},
    ]

    worker1_gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='worker',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(worker1_name_to_ip))

    worker2_gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='worker',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(worker2_name_to_ip))

    ps_gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='ps',
        port=2222,
        credentials=None,
        service=self.gen_standard_mock_service_client(ps_name_to_ip))

    union_cluster_resolver = UnionClusterResolver(worker1_gce_cluster_resolver,
                                                  worker2_gce_cluster_resolver,
                                                  ps_gce_cluster_resolver)

    actual_cluster_spec = union_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'ps' tasks { key: 0 value: '10.100.1.2:2222' }
                     tasks { key: 1 value: '10.100.2.3:2222' } }
    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' }
                         tasks { key: 1 value: '10.2.3.4:8470' }
                         tasks { key: 2 value: '10.3.4.5:8470' }
                         tasks { key: 3 value: '10.4.5.6:8470' }
                         tasks { key: 4 value: '10.5.6.7:8470' }
                         tasks { key: 5 value: '10.6.7.8:8470' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  def testSettingTaskTypeRaiseError(self):
    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='testworker',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(self._THREE_INSTANCES))

    with self.assertRaisesRegex(
        RuntimeError, 'You cannot reset the task_type '
        'of the GCEClusterResolver after it has '
        'been created.'):
      gce_cluster_resolver.task_type = 'foobar'
343
344
if __name__ == '__main__':
  # Run this module's test cases through TensorFlow's test runner.
  test.main()
347