# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for GCEClusterResolver."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute.cluster_resolver.cluster_resolver import UnionClusterResolver
from tensorflow.python.distribute.cluster_resolver.gce_cluster_resolver import GCEClusterResolver
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib


mock = test.mock


class GCEClusterResolverTest(test.TestCase):
  """Exercises GCEClusterResolver against a mocked GCE service client."""

  def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
    """Asserts that `cluster_spec` matches `expected_proto` in every form.

    The spec is compared directly, and again after being rebuilt into a
    `server_lib.ClusterSpec` from the spec itself, from its cluster def, and
    from its dict form.

    Args:
      cluster_spec: The cluster spec object returned by a resolver.
      expected_proto: Text-format cluster def the spec must equal.
    """
    self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
    self.assertProtoEquals(
        expected_proto, server_lib.ClusterSpec(cluster_spec).as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
    self.assertProtoEquals(
        expected_proto,
        server_lib.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())

  def standard_mock_instance_groups(self, instance_map=None):
    """Builds a mock of the `instanceGroups()` service resource.

    Args:
      instance_map: List of `{'instance': <url>}` dicts to report as the
        group's members. Defaults to a single instance, `gce-instance-1`.

    Returns:
      A mock whose `listInstances()` returns a request object whose
      `execute()` yields `{'items': instance_map}`, and whose
      `listInstances_next()` returns None.
    """
    if instance_map is None:
      instance_map = [
          {'instance': 'https://gce.example.com/res/gce-instance-1'}
      ]

    mock_instance_group_request = mock.MagicMock()
    mock_instance_group_request.execute.return_value = {
        'items': instance_map
    }

    service_attrs = {
        'listInstances.return_value': mock_instance_group_request,
        'listInstances_next.return_value': None,
    }
    mock_instance_groups = mock.Mock(**service_attrs)
    return mock_instance_groups

  def standard_mock_instances(self, instance_to_ip_map=None):
    """Builds a mock of the `instances()` service resource.

    Args:
      instance_to_ip_map: Dict mapping instance name to its network IP.
        Defaults to `{'gce-instance-1': '10.123.45.67'}`.

    Returns:
      A mock whose `get(project, zone, instance)` returns a request object
      whose `execute()` yields the instance's network interface data, and
      raises RuntimeError for an instance name that is not in the map.
    """
    if instance_to_ip_map is None:
      instance_to_ip_map = {
          'gce-instance-1': '10.123.45.67'
      }

    def get_side_effect(project, zone, instance):
      del project, zone  # Unused

      if instance not in instance_to_ip_map:
        raise RuntimeError('Instance %s not found!' % instance)

      mock_get_request = mock.MagicMock()
      mock_get_request.execute.return_value = {
          'networkInterfaces': [
              {'networkIP': instance_to_ip_map[instance]}
          ]
      }
      return mock_get_request

    service_attrs = {
        'get.side_effect': get_side_effect,
    }
    mock_instances = mock.MagicMock(**service_attrs)
    return mock_instances

  def standard_mock_service_client(
      self,
      mock_instance_groups=None,
      mock_instances=None):
    """Builds a mock GCE service client from the two resource mocks.

    Args:
      mock_instance_groups: Mock returned by `client.instanceGroups()`.
        Defaults to `standard_mock_instance_groups()`.
      mock_instances: Mock returned by `client.instances()`. Defaults to
        `standard_mock_instances()`.

    Returns:
      A MagicMock standing in for the service client.
    """
    if mock_instance_groups is None:
      mock_instance_groups = self.standard_mock_instance_groups()
    if mock_instances is None:
      mock_instances = self.standard_mock_instances()

    mock_client = mock.MagicMock()
    mock_client.instanceGroups.return_value = mock_instance_groups
    mock_client.instances.return_value = mock_instances
    return mock_client

  def gen_standard_mock_service_client(self, instances=None):
    """Builds a mock service client for a list of named instances.

    Args:
      instances: Iterable of `{'name': <str>, 'ip': <str>}` dicts, in the
        order the instance group should list them. None (the default) is
        treated as an empty group.

    Returns:
      A MagicMock standing in for the service client.
    """
    # The default of None used to reach the loop below and raise TypeError.
    if instances is None:
      instances = []

    name_to_ip = {}
    instance_list = []
    for instance in instances:
      name_to_ip[instance['name']] = instance['ip']
      instance_list.append({
          'instance': 'https://gce.example.com/gce/res/' + instance['name']
      })

    mock_instance = self.standard_mock_instances(name_to_ip)
    mock_instance_group = self.standard_mock_instance_groups(instance_list)

    return self.standard_mock_service_client(mock_instance_group, mock_instance)

  def testSimpleSuccessfulRetrieval(self):
    """A single-instance group yields a one-task 'worker' job."""
    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        port=8470,
        credentials=None,
        service=self.standard_mock_service_client())

    actual_cluster_spec = gce_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'worker' tasks { key: 0 value: '10.123.45.67:8470' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  def testMasterRetrieval(self):
    """master() returns the grpc address of the configured task."""
    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_id=0,
        port=8470,
        credentials=None,
        service=self.standard_mock_service_client())
    self.assertEqual(gce_cluster_resolver.master(), 'grpc://10.123.45.67:8470')

  def testMasterRetrievalWithCustomTasks(self):
    """master() honors an explicit task type, task id and rpc layer."""
    name_to_ip = [
        {'name': 'instance1', 'ip': '10.1.2.3'},
        {'name': 'instance2', 'ip': '10.2.3.4'},
        {'name': 'instance3', 'ip': '10.3.4.5'},
    ]

    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(name_to_ip))

    self.assertEqual(
        gce_cluster_resolver.master('worker', 2, 'test'),
        'test://10.3.4.5:8470')

  def testOverrideParameters(self):
    """task_id and rpc_layer set after construction are used by master()."""
    name_to_ip = [
        {'name': 'instance1', 'ip': '10.1.2.3'},
        {'name': 'instance2', 'ip': '10.2.3.4'},
        {'name': 'instance3', 'ip': '10.3.4.5'},
    ]

    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='testworker',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(name_to_ip))

    gce_cluster_resolver.task_id = 1
    gce_cluster_resolver.rpc_layer = 'test'

    self.assertEqual(gce_cluster_resolver.task_type, 'testworker')
    self.assertEqual(gce_cluster_resolver.task_id, 1)
    self.assertEqual(gce_cluster_resolver.rpc_layer, 'test')
    self.assertEqual(gce_cluster_resolver.master(), 'test://10.2.3.4:8470')

  def testOverrideParametersWithZeroOrEmpty(self):
    """Falsy arguments to master() ('' and 0) are not ignored."""
    name_to_ip = [
        {'name': 'instance1', 'ip': '10.1.2.3'},
        {'name': 'instance2', 'ip': '10.2.3.4'},
        {'name': 'instance3', 'ip': '10.3.4.5'},
    ]

    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='',
        task_id=1,
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(name_to_ip))

    # task_id=0 must win over the constructor's task_id=1.
    self.assertEqual(gce_cluster_resolver.master(
        task_type='', task_id=0), 'grpc://10.1.2.3:8470')

  def testCustomJobNameAndPortRetrieval(self):
    """A custom task_type and port appear in the resulting cluster spec."""
    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='custom',
        port=2222,
        credentials=None,
        service=self.standard_mock_service_client())

    actual_cluster_spec = gce_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'custom' tasks { key: 0 value: '10.123.45.67:2222' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  def testMultipleInstancesRetrieval(self):
    """Each instance in the group becomes one task of the 'worker' job."""
    name_to_ip = [
        {'name': 'instance1', 'ip': '10.1.2.3'},
        {'name': 'instance2', 'ip': '10.2.3.4'},
        {'name': 'instance3', 'ip': '10.3.4.5'},
    ]

    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(name_to_ip))

    actual_cluster_spec = gce_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' }
                         tasks { key: 1 value: '10.2.3.4:8470' }
                         tasks { key: 2 value: '10.3.4.5:8470' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  def testUnionMultipleInstanceRetrieval(self):
    """UnionClusterResolver merges two worker groups and one ps group."""
    worker1_name_to_ip = [
        {'name': 'instance1', 'ip': '10.1.2.3'},
        {'name': 'instance2', 'ip': '10.2.3.4'},
        {'name': 'instance3', 'ip': '10.3.4.5'},
    ]

    worker2_name_to_ip = [
        {'name': 'instance4', 'ip': '10.4.5.6'},
        {'name': 'instance5', 'ip': '10.5.6.7'},
        {'name': 'instance6', 'ip': '10.6.7.8'},
    ]

    ps_name_to_ip = [
        {'name': 'ps1', 'ip': '10.100.1.2'},
        {'name': 'ps2', 'ip': '10.100.2.3'},
    ]

    worker1_gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='worker',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(worker1_name_to_ip))

    worker2_gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='worker',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(worker2_name_to_ip))

    ps_gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='ps',
        port=2222,
        credentials=None,
        service=self.gen_standard_mock_service_client(ps_name_to_ip))

    union_cluster_resolver = UnionClusterResolver(worker1_gce_cluster_resolver,
                                                  worker2_gce_cluster_resolver,
                                                  ps_gce_cluster_resolver)

    actual_cluster_spec = union_cluster_resolver.cluster_spec()
    expected_proto = """
    job { name: 'ps' tasks { key: 0 value: '10.100.1.2:2222' }
                     tasks { key: 1 value: '10.100.2.3:2222' } }
    job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' }
                         tasks { key: 1 value: '10.2.3.4:8470' }
                         tasks { key: 2 value: '10.3.4.5:8470' }
                         tasks { key: 3 value: '10.4.5.6:8470' }
                         tasks { key: 4 value: '10.5.6.7:8470' }
                         tasks { key: 5 value: '10.6.7.8:8470' } }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

  def testSettingTaskTypeRaiseError(self):
    """Assigning task_type after construction raises RuntimeError."""
    name_to_ip = [
        {
            'name': 'instance1',
            'ip': '10.1.2.3'
        },
        {
            'name': 'instance2',
            'ip': '10.2.3.4'
        },
        {
            'name': 'instance3',
            'ip': '10.3.4.5'
        },
    ]

    gce_cluster_resolver = GCEClusterResolver(
        project='test-project',
        zone='us-east1-d',
        instance_group='test-instance-group',
        task_type='testworker',
        port=8470,
        credentials=None,
        service=self.gen_standard_mock_service_client(name_to_ip))

    with self.assertRaisesRegex(
        RuntimeError, 'You cannot reset the task_type '
        'of the GCEClusterResolver after it has '
        'been created.'):
      gce_cluster_resolver.task_type = 'foobar'


if __name__ == '__main__':
  test.main()