# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for cloud tpu client."""

import datetime
import json
import os
import time
import urllib

from absl import flags

from tensorflow.python.platform import test
from tensorflow.python.tpu.client import client

FLAGS = flags.FLAGS

mock = test.mock

_UTCNOW_STR = '2000-01-01T00:30:00'

# Fully qualified resource name under which the mocked TPU node is registered.
_TPU_FULL_NAME = 'projects/test-project/locations/us-central1-c/nodes/tpu_name'

# Environment variables that individual tests set and that must never leak
# from one test into the next.
_TPU_ENV_VARS = ('TPU_API_DISCOVERY_URL', 'TPU_NAME', 'TPU_CONFIG')


def mock_utcnow():
  """Returns the fixed "current" UTC time used by the symptom tests."""
  return datetime.datetime.strptime(_UTCNOW_STR, '%Y-%m-%dT%H:%M:%S')


def mock_request_compute_metadata(path):
  """Returns canned metadata values for the paths the tests rely on.

  Args:
    path: Metadata path being requested.

  Returns:
    The canned value for a known path, or an empty string otherwise.
  """
  if path == 'project/project-id':
    return 'test-project'
  elif path == 'instance/zone':
    return 'projects/test-project/locations/us-central1-c'
  elif path == 'instance/network-interfaces/0/ip':
    return '10.128.1.2'
  return ''


def _tpu_map(node):
  """Wraps a single node description in a map keyed by its full name."""
  return {_TPU_FULL_NAME: node}


def _ready_tpu_map(symptoms=None):
  """Builds a map with one READY node, optionally carrying `symptoms`.

  Args:
    symptoms: Optional list of symptom dicts. When None, the node has no
      'symptoms' key at all.

  Returns:
    A dict mapping the node's full name to its description.
  """
  node = {'state': 'READY'}
  if symptoms is not None:
    node['symptoms'] = symptoms
  return _tpu_map(node)


def _symptom(minute_second, symptom_type, component):
  """Builds one symptom entry.

  Args:
    minute_second: 'MM:SS' part of the creation time; the date, hour and
      fractional seconds are fixed.
    symptom_type: Value stored under 'symptomType'.
    component: Word inserted into the details message ('runtime' or 'HBM').

  Returns:
    A symptom dict.
  """
  return {
      'createTime': '2000-01-01T00:%s.123456Z' % minute_second,
      'symptomType': symptom_type,
      'details': ('The TPU %s has run OOM at timestamp '
                  '2020-05-29T04:51:32.038721+00:00' % component),
      'workerId': '0',
  }


def _oom_test_cases(oom_type, component):
  """Returns `(tpu_map, want_recoverable)` pairs shared by the OOM tests.

  All creation times are relative to the mocked "now" of 00:30:00. In the
  expected values below, an OOM symptom created at 00:29:30 makes the TPU
  unrecoverable, one created at 00:28:20 does not, and LOW_MEMORY symptoms
  never do.

  Args:
    oom_type: Symptom type under test, e.g. 'OUT_OF_MEMORY'.
    component: Word inserted into each details message.

  Returns:
    A list of `(tpu_map, want)` tuples.
  """

  def oom(minute_second):
    return _symptom(minute_second, oom_type, component)

  def low(minute_second):
    return _symptom(minute_second, 'LOW_MEMORY', component)

  return [
      (_ready_tpu_map(), True),
      (_ready_tpu_map([oom('29:30')]), False),
      (_ready_tpu_map([oom('28:20')]), True),
      (_ready_tpu_map([low('28:40'), oom('29:30'), low('29:40')]), False),
      (_ready_tpu_map([oom('28:20'), low('29:30'), low('29:40')]), True),
      (_ready_tpu_map([
          low('29:00'),
          low('29:10'),
          low('29:20'),
          low('29:30'),
          low('29:40'),
      ]), True),
  ]


def _recent_oom_test_cases(oom_type, component):
  """Returns the single "OOM at 00:29:30, still recoverable" test case."""
  return [
      (_ready_tpu_map([_symptom('29:30', oom_type, component)]), True),
  ]


class MockRequestClass:
  """Stand-in for an API request object whose `execute` returns a node."""

  def __init__(self, name, tpu_map):
    self._name = name
    self._tpu_map = tpu_map

  def execute(self):
    """Returns a copy of the requested node.

    A list-valued 'health' entry is treated as a time series and collapsed to
    the single entry at index `time.time()`; this requires `time.time` to be
    patched to return a small integer, as the wait-for-healthy tests do.

    Raises:
      KeyError: If the requested name is not in the map.
    """
    if self._name not in self._tpu_map:
      raise KeyError('Resource %s was not found' % self._name)
    tpu_dict = self._tpu_map[self._name].copy()
    health = tpu_dict.get('health')
    if isinstance(health, list):
      # Do extraction of health list to a single health string based on time.
      tpu_dict['health'] = health[time.time()]
    return tpu_dict


class MockNodeClass:
  """Stand-in for the `nodes()` API resource."""

  def __init__(self, tpu_map):
    self._tpu_map = tpu_map

  def get(self, name):
    return MockRequestClass(name, self._tpu_map)


class CloudTpuClientTest(test.TestCase):

  def setUp(self):
    super().setUp()
    self.addCleanup(mock.patch.stopall)
    # Snapshot the environment so that anything a test writes to os.environ
    # (TPU_NAME, TPU_CONFIG, ...) is rolled back when the test finishes.
    env_patcher = mock.patch.dict(os.environ)
    env_patcher.start()
    self.addCleanup(env_patcher.stop)
    for name in _TPU_ENV_VARS:
      os.environ.pop(name, None)
    self._time_now = 0

  def _mock_time(self, *args, **kwargs):
    return self._time_now

  def _mock_sleep(self, secs):
    self._time_now += secs

  def _patch_clock(self):
    """Patches `time.time`/`time.sleep` with a fake clock.

    Returns:
      The mock that replaced `time.sleep`.
    """
    time_mock = mock.patch.object(time, 'time', autospec=True).start()
    time_mock.side_effect = self._mock_time
    sleep_mock = mock.patch.object(time, 'sleep', autospec=True).start()
    sleep_mock.side_effect = self._mock_sleep
    return sleep_mock

  def _set_flag(self, name, value):
    """Sets a flag for this test and restores its previous value afterwards."""
    # Registered before the assignment so the flag is restored even when a
    # later assertion fails.
    self.addCleanup(setattr, FLAGS, name, getattr(FLAGS, name))
    setattr(FLAGS, name, value)

  def _assert_recoverable(self, test_cases, tpu='tpu_name'):
    """Checks `Client.recoverable()` against every `(tpu_map, want)` pair."""
    for index, (tpu_map, want) in enumerate(test_cases):
      with self.subTest(case=index):
        c = client.Client(
            tpu=tpu, service=self.mock_service_client(tpu_map=tpu_map))
        self.assertEqual(want, c.recoverable())

  def mock_service_client(self, tpu_map=None):
    """Builds a mock service whose nodes are served from `tpu_map`."""
    if tpu_map is None:
      tpu_map = {}

    mock_locations = mock.MagicMock()
    mock_locations.nodes.return_value = MockNodeClass(tpu_map)

    mock_project = mock.MagicMock()
    mock_project.locations.return_value = mock_locations

    mock_client = mock.MagicMock()
    mock_client.projects.return_value = mock_project
    return mock_client

  def testEnvironmentDiscoveryUrl(self):
    os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
    self.assertEqual('https://{api}.internal/{apiVersion}',
                     (client._environment_discovery_url()))

  def testEnvironmentVarToNetworkEndpointsSingleIp(self):
    self.assertEqual(
        [{'ipAddress': '1.2.3.4', 'port': '1234'}],
        list(client._environment_var_to_network_endpoints(
            '1.2.3.4:1234')))

  def testEnvironmentVarToNetworkEndpointsSingleGrpcAddress(self):
    self.assertEqual(
        [{'ipAddress': '1.2.3.4', 'port': '2000'}],
        list(
            client._environment_var_to_network_endpoints(
                'grpc://1.2.3.4:2000')))

  def testEnvironmentVarToNetworkEndpointsMultipleIps(self):
    self.assertEqual(
        [{'ipAddress': '1.2.3.4', 'port': '2000'},
         {'ipAddress': '5.6.7.8', 'port': '1234'}],
        list(
            client._environment_var_to_network_endpoints(
                '1.2.3.4:2000,5.6.7.8:1234')))

  def testEnvironmentVarToNetworkEndpointsMultipleGrpcAddresses(self):
    self.assertEqual(
        [{'ipAddress': '1.2.3.4', 'port': '2000'},
         {'ipAddress': '5.6.7.8', 'port': '1234'}],
        list(client._environment_var_to_network_endpoints(
            'grpc://1.2.3.4:2000,grpc://5.6.7.8:1234')))

  def testEnvironmentVarToNetworkEndpointsMissingPortAndMixed(self):
    self.assertEqual(
        [{'ipAddress': '1.2.3.4', 'port': '2000'},
         {'ipAddress': '5.6.7.8', 'port': '8470'}],
        list(client._environment_var_to_network_endpoints(
            '1.2.3.4:2000,grpc://5.6.7.8')))

  def testInitializeNoArguments(self):
    with self.assertRaisesRegex(
        ValueError, 'Please provide a TPU Name to connect to.'):
      client.Client()

  def testInitializeMultiElementTpuArray(self):
    with self.assertRaisesRegex(
        NotImplementedError,
        'Using multiple TPUs in a single session is not yet implemented'):
      client.Client(tpu=['multiple', 'elements'])

  def assertClientContains(self, c):
    self.assertEqual('tpu_name', c._tpu)
    self.assertEqual(True, c._use_api)
    self.assertIsNone(c._credentials)
    self.assertEqual('test-project', c._project)
    self.assertEqual('us-central1-c', c._zone)
    self.assertIsNone(c._discovery_url)
    self.assertEqual([{
        'ipAddress': '10.1.2.3',
        'port': '8470'
    }], c.network_endpoints())

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testNetworkEndpointsNotReadyWithApi(self):
    tpu_map = _tpu_map({
        'ipAddress': '10.1.2.3',
        'port': '8470',
    })
    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertRaisesRegex(
        RuntimeError, 'TPU .* is not yet ready; state: "None"',
        c.network_endpoints)

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testInitializeNoArgumentsWithEnvironmentVariable(self):
    os.environ['TPU_NAME'] = 'tpu_name'
    tpu_map = _tpu_map({
        'ipAddress': '10.1.2.3',
        'port': '8470',
        'state': 'READY',
        'health': 'HEALTHY',
    })
    c = client.Client(
        service=self.mock_service_client(tpu_map=tpu_map))
    self.assertClientContains(c)

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testInitializeNoArgumentsWithTPUEnvironmentVariableTPUConfig(self):
    os.environ['TPU_CONFIG'] = json.dumps({
        'project': 'test-project',
        'zone': 'us-central1-c',
        'tpu_node_name': 'tpu_name',
    })
    tpu_map = _tpu_map({
        'ipAddress': '10.1.2.3',
        'port': '8470',
        'state': 'READY',
        'health': 'HEALTHY',
    })
    c = client.Client(service=self.mock_service_client(tpu_map=tpu_map))
    self.assertClientContains(c)

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testInitializeTpuName(self):
    tpu_map = _tpu_map({
        'ipAddress': '10.1.2.3',
        'port': '8470',
        'state': 'READY',
        'health': 'HEALTHY',
    })
    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertClientContains(c)

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testInitializeIpAddress(self):
    c = client.Client(tpu='grpc://1.2.3.4:8470')
    self.assertEqual('grpc://1.2.3.4:8470', c._tpu)
    self.assertEqual(False, c._use_api)
    self.assertIsNone(c._service)
    self.assertIsNone(c._credentials)
    self.assertIsNone(c._project)
    self.assertIsNone(c._zone)
    self.assertIsNone(c._discovery_url)
    self.assertEqual([{
        'ipAddress': '1.2.3.4',
        'port': '8470'
    }], c.network_endpoints())

  def testInitializeWithoutMetadata(self):
    c = client.Client(
        tpu='tpu_name', project='project', zone='zone')
    self.assertEqual('tpu_name', c._tpu)
    self.assertEqual(True, c._use_api)
    self.assertIsNone(c._service)
    self.assertIsNone(c._credentials)
    self.assertEqual('project', c._project)
    self.assertEqual('zone', c._zone)
    self.assertIsNone(c._discovery_url)

  def testRecoverableNoApiAccess(self):
    c = client.Client(tpu='grpc://1.2.3.4:8470')
    self.assertEqual(True, c.recoverable())

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testRecoverableNoState(self):
    tpu_map = _tpu_map({
        'ipAddress': '10.1.2.3',
        'port': '8470',
    })
    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertEqual(True, c.recoverable())

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testRecoverableReady(self):
    tpu_map = _tpu_map({
        'ipAddress': '10.1.2.3',
        'port': '8470',
        'state': 'READY',
    })
    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertEqual(True, c.recoverable())

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testRecoverablePreempted(self):
    tpu_map = _tpu_map({
        'ipAddress': '10.1.2.3',
        'port': '8470',
        'state': 'PREEMPTED',
    })
    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertEqual(False, c.recoverable())

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  @mock.patch.object(client, '_utcnow', mock_utcnow)
  def testRecoverableOOM(self):
    self._assert_recoverable(_oom_test_cases('OUT_OF_MEMORY', 'runtime'))

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  @mock.patch.object(client, '_utcnow', mock_utcnow)
  def testRecoverableOOMDisabled(self):
    self._set_flag('runtime_oom_exit', False)
    self._assert_recoverable(
        _recent_oom_test_cases('OUT_OF_MEMORY', 'runtime'))

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  @mock.patch.object(client, '_utcnow', mock_utcnow)
  def testRecoverableOOMNoAPI(self):
    self._assert_recoverable(
        _recent_oom_test_cases('OUT_OF_MEMORY', 'runtime'),
        tpu='grpc://1.2.3.4:8470')

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  @mock.patch.object(client, '_utcnow', mock_utcnow)
  def testRecoverableHBMOOM(self):
    self._assert_recoverable(_oom_test_cases('HBM_OUT_OF_MEMORY', 'HBM'))

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  @mock.patch.object(client, '_utcnow', mock_utcnow)
  def testRecoverableHBMOOMDisabled(self):
    self._set_flag('hbm_oom_exit', False)
    self._assert_recoverable(
        _recent_oom_test_cases('HBM_OUT_OF_MEMORY', 'HBM'))

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  @mock.patch.object(client, '_utcnow', mock_utcnow)
  def testRecoverableHBMOOMNoAPI(self):
    self._assert_recoverable(
        _recent_oom_test_cases('HBM_OUT_OF_MEMORY', 'HBM'),
        tpu='grpc://1.2.3.4:8470')

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testHealthApi(self):
    tpu_map = _tpu_map({
        'ipAddress': '10.1.2.3',
        'port': '8470',
        'state': 'PREEMPTED',
        'health': 'HEALTHY',
        'acceleratorType': 'v3-8',
        'tensorflowVersion': 'nightly',
    })
    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertEqual('HEALTHY', c.health())

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testRuntimeVersionApi(self):
    tpu_map = _tpu_map({
        'ipAddress': '10.1.2.3',
        'port': '8470',
        'state': 'PREEMPTED',
        'health': 'HEALTHY',
        'acceleratorType': 'v3-8',
        'tensorflowVersion': 'nightly',
    })
    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertEqual('nightly', c.runtime_version())

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testAcceleratorTypeApi(self):
    tpu_map = _tpu_map({
        'ipAddress': '10.1.2.3',
        'port': '8470',
        'state': 'PREEMPTED',
        'health': 'HEALTHY',
        'acceleratorType': 'v3-8',
        'tensorflowVersion': 'nightly',
    })
    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertEqual('v3-8', c.accelerator_type())

  def testHandlesByteStrings(self):
    self.assertEqual(
        client.Client(
            tpu='tpu_name', zone='zone', project='project')._full_name(),
        client.Client(
            tpu=b'tpu_name', zone=b'zone', project=b'project')._full_name(),
    )

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testWaitForHealthy(self):
    sleep_mock = self._patch_clock()

    health_timeseries = (['UNHEALTHY_MAINTENANCE']*30 + ['TIMEOUT']*10
                         + [None]*20 + ['HEALTHY']*30)
    tpu_map = _tpu_map({
        'ipAddress': '10.1.2.3',
        'port': '8470',
        'state': 'READY',
        'health': health_timeseries,
    })

    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))

    # Doesn't throw RuntimeError as TPU becomes HEALTHY before timeout
    timeout = 80
    interval = 5
    return_time = 60
    c.wait_for_healthy(timeout_s=timeout, interval=interval)
    self.assertEqual(time.time(), return_time)
    self.assertEqual(sleep_mock.call_count, return_time // interval)

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testWaitForHealthyRaisesError(self):
    self._patch_clock()

    # Mock timeseries where takes longer than timeout.
    health_timeseries = ['UNHEALTHY_MAINTENANCE']*50 + ['TIMEOUT']*50
    tpu_map = _tpu_map({
        'ipAddress': '10.1.2.3',
        'port': '8470',
        'state': 'READY',
        'health': health_timeseries,
    })

    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))

    # Throws RuntimeError as the TPU never becomes HEALTHY before the timeout.
    with self.assertRaisesRegex(
        RuntimeError,
        'Timed out waiting for TPU .* to become healthy'):
      c.wait_for_healthy(timeout_s=80, interval=5)

  def baseConfigureTpuVersion(self):
    """Returns a client for a READY node that has two network endpoints."""
    tpu_map = _tpu_map({
        'state': 'READY',
        'networkEndpoints': [
            {
                'ipAddress': '1.2.3.4'
            },
            {
                'ipAddress': '5.6.7.8'
            },
        ]
    })
    return client.Client(
        tpu='tpu_name',
        project='test-project',
        zone='us-central1-c',
        service=self.mock_service_client(tpu_map=tpu_map))

  @mock.patch.object(urllib.request, 'urlopen')
  def testConfigureTpuVersion(self, urlopen):
    c = self.baseConfigureTpuVersion()
    c.configure_tpu_version('1.15')
    paths = [call[0][0].full_url for call in urlopen.call_args_list]
    self.assertCountEqual([
        'http://1.2.3.4:8475/requestversion/1.15?restartType=always',
        'http://5.6.7.8:8475/requestversion/1.15?restartType=always'
    ], paths)

  @mock.patch.object(urllib.request, 'urlopen')
  def testConfigureTpuVersionRestartIfneeded(self, urlopen):
    c = self.baseConfigureTpuVersion()
    c.configure_tpu_version('1.15', restart_type='ifNeeded')
    paths = [call[0][0].full_url for call in urlopen.call_args_list]
    self.assertCountEqual([
        'http://1.2.3.4:8475/requestversion/1.15?restartType=ifNeeded',
        'http://5.6.7.8:8475/requestversion/1.15?restartType=ifNeeded'
    ], paths)

  @mock.patch.object(urllib.request, 'urlopen')
  def testGetTpuVersion(self, urlopen):
    c = client.Client(
        tpu='grpc://1.2.3.4:8470')
    resp = mock.Mock()
    resp.read.side_effect = ['{}', '{"currentVersion": "someVersion"}']
    urlopen.return_value = resp
    self.assertIsNone(c.runtime_version(), 'Missing key should be handled.')
    self.assertEqual(
        'someVersion', c.runtime_version(), 'Should return configured version.')
    paths = [call[0][0].full_url for call in urlopen.call_args_list]
    self.assertCountEqual([
        'http://1.2.3.4:8475/requestversion',
        'http://1.2.3.4:8475/requestversion',
    ], paths)


if __name__ == '__main__':
  test.main()