• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for cloud tpu client."""
16
17import datetime
18import json
19import os
20import time
21import urllib
22
23from absl import flags
24
25from tensorflow.python.platform import test
26from tensorflow.python.tpu.client import client
27
# The mock library as exposed by TensorFlow's test platform module; used below
# for patch.object / MagicMock.
mock = test.mock

# Process-wide absl flag values; some tests below temporarily override
# runtime_oom_exit / hbm_oom_exit.
FLAGS = flags.FLAGS
31
# Frozen "current time" served by mock_utcnow so that symptom-age checks in
# the tests are deterministic.
_UTCNOW_STR = '2000-01-01T00:30:00'


def mock_utcnow():
  """Stand-in for client._utcnow: the naive datetime spelled by _UTCNOW_STR."""
  iso_format = '%Y-%m-%dT%H:%M:%S'
  frozen_now = datetime.datetime.strptime(_UTCNOW_STR, iso_format)
  return frozen_now
37
38
def mock_request_compute_metadata(path):
  """Fake GCE metadata lookup returning canned values for known paths.

  Args:
    path: metadata path relative to the metadata server root.

  Returns:
    The canned value for `path`, or '' for any path not listed below.
  """
  canned_answers = (
      ('project/project-id', 'test-project'),
      ('instance/zone', 'projects/test-project/locations/us-central1-c'),
      ('instance/network-interfaces/0/ip', '10.128.1.2'),
  )
  for known_path, answer in canned_answers:
    if path == known_path:
      return answer
  return ''
47
48
class MockRequestClass:
  """Fake request object handed out by MockNodeClass.get().

  `execute()` returns a shallow copy of the node dict registered under `name`
  in `tpu_map`. When the node's 'health' entry is a list, it is treated as a
  time series sampled once per second of `time.time()`, and the copy carries
  only the sample for the current second.
  """

  def __init__(self, name, tpu_map):
    self._name = name
    self._tpu_map = tpu_map

  def execute(self):
    """Returns the node dict registered for this request's name.

    Raises:
      KeyError: if no node is registered under the requested name.
    """
    if self._name not in self._tpu_map:
      raise KeyError('Resource %s was not found' % self._name)

    tpu_dict = self._tpu_map[self._name].copy()
    health_series = tpu_dict.get('health')
    if isinstance(health_series, list):
      if health_series:
        # time.time() is a float under the real clock; index by whole seconds
        # and hold the final sample once the clock runs past the series end.
        second = int(time.time())
        index = min(max(second, 0), len(health_series) - 1)
        tpu_dict['health'] = health_series[index]
      else:
        tpu_dict['health'] = None
    return tpu_dict
66
67
class MockNodeClass:
  """Fake `nodes()` collection that builds MockRequestClass lookups."""

  def __init__(self, tpu_map):
    # Shared with every request created by get(); not copied.
    self._tpu_map = tpu_map

  def get(self, name):
    """Returns a request that will look `name` up in the shared tpu_map."""
    shared_map = self._tpu_map
    return MockRequestClass(name=name, tpu_map=shared_map)
75
76
class CloudTpuClientTest(test.TestCase):
  """Tests for `client.Client` against a mocked Cloud TPU service."""

  # Resource name the Client resolves for tpu='tpu_name', given the project
  # and zone served by mock_request_compute_metadata.
  _NODE_NAME = 'projects/test-project/locations/us-central1-c/nodes/tpu_name'

  _RUNTIME_OOM_DETAILS = ('The TPU runtime has run OOM at timestamp '
                          '2020-05-29T04:51:32.038721+00:00')
  _HBM_OOM_DETAILS = ('The TPU HBM has run OOM at timestamp '
                      '2020-05-29T04:51:32.038721+00:00')

  # Environment variables the tests below show the Client consulting; every
  # test starts with all of them unset.
  _CLIENT_ENV_VARS = ('TPU_API_DISCOVERY_URL', 'TPU_NAME', 'TPU_CONFIG')

  def setUp(self):
    super().setUp()
    # Restore the full environment after each test so variables set by one
    # test (TPU_NAME, TPU_CONFIG, ...) cannot leak into the next one.
    self.addCleanup(self._restore_environ, dict(os.environ))
    for var in self._CLIENT_ENV_VARS:
      os.environ.pop(var, None)
    self._time_now = 0
    self.addCleanup(mock.patch.stopall)

  @staticmethod
  def _restore_environ(saved_environ):
    """Replaces the contents of os.environ with `saved_environ`."""
    os.environ.clear()
    os.environ.update(saved_environ)

  def _mock_time(self, *args, **kwargs):
    """Fake time.time(): returns the virtual clock, in seconds."""
    return self._time_now

  def _mock_sleep(self, secs):
    """Fake time.sleep(): advances the virtual clock instead of blocking."""
    self._time_now += secs

  def _install_fake_clock(self):
    """Routes time.time/time.sleep to the virtual clock.

    The patches are undone by the mock.patch.stopall cleanup from setUp.

    Returns:
      The mock standing in for time.sleep.
    """
    time_mock = mock.patch.object(time, 'time', autospec=True).start()
    time_mock.side_effect = self._mock_time
    sleep_mock = mock.patch.object(time, 'sleep', autospec=True).start()
    sleep_mock.side_effect = self._mock_sleep
    return sleep_mock

  def _override_flag(self, name, value):
    """Sets absl flag `name` to `value` for the rest of this test only.

    The previous value is restored by a cleanup, so it is reset even when an
    assertion in the test fails.
    """
    self.addCleanup(setattr, FLAGS, name, getattr(FLAGS, name))
    setattr(FLAGS, name, value)

  def mock_service_client(self, tpu_map=None):
    """Returns a fake service client backed by `tpu_map`.

    `service.projects().locations().nodes()` yields a MockNodeClass over
    `tpu_map` (an empty map when None).
    """
    if tpu_map is None:
      tpu_map = {}

    mock_locations = mock.MagicMock()
    mock_locations.nodes.return_value = MockNodeClass(tpu_map)

    mock_project = mock.MagicMock()
    mock_project.locations.return_value = mock_locations

    mock_client = mock.MagicMock()
    mock_client.projects.return_value = mock_project
    return mock_client

  def _tpu_map(self, **node_fields):
    """Returns a tpu_map with one node at 10.1.2.3:8470 plus `node_fields`."""
    node = {'ipAddress': '10.1.2.3', 'port': '8470'}
    node.update(node_fields)
    return {self._NODE_NAME: node}

  def _symptom(self, clock, symptom_type, details):
    """Returns one symptom record created at 2000-01-01T<clock>.123456Z."""
    return {
        'createTime': '2000-01-01T%s.123456Z' % clock,
        'symptomType': symptom_type,
        'details': details,
        'workerId': '0',
    }

  def _ready_map(self, symptoms=()):
    """Returns a tpu_map with one READY node and, if given, its symptoms."""
    node = {'state': 'READY'}
    if symptoms:
      node['symptoms'] = list(symptoms)
    return {self._NODE_NAME: node}

  def _oom_cases(self, oom_type, details):
    """Returns (tpu_map, expected recoverable()) pairs for `oom_type`.

    Symptom clocks are compared against the mocked current time
    2000-01-01T00:30:00 (see mock_utcnow). Every symptom, including the
    LOW_MEMORY ones, carries `details`.
    """

    def symptom(clock, symptom_type):
      return self._symptom(clock, symptom_type, details)

    low = 'LOW_MEMORY'
    return [
        # No symptoms at all.
        (self._ready_map(), True),
        # An OOM at 00:29:30.
        (self._ready_map([symptom('00:29:30', oom_type)]), False),
        # An OOM at 00:28:20 only.
        (self._ready_map([symptom('00:28:20', oom_type)]), True),
        # An OOM at 00:29:30 surrounded by LOW_MEMORY symptoms.
        (self._ready_map([
            symptom('00:28:40', low),
            symptom('00:29:30', oom_type),
            symptom('00:29:40', low),
        ]), False),
        # An OOM at 00:28:20 followed by LOW_MEMORY symptoms.
        (self._ready_map([
            symptom('00:28:20', oom_type),
            symptom('00:29:30', low),
            symptom('00:29:40', low),
        ]), True),
        # LOW_MEMORY symptoms only.
        (self._ready_map([
            symptom(clock, low)
            for clock in ('00:29:00', '00:29:10', '00:29:20', '00:29:30',
                          '00:29:40')
        ]), True),
    ]

  def _assert_recoverable(self, cases, tpu='tpu_name'):
    """Checks Client(tpu=tpu).recoverable() against each (tpu_map, want)."""
    for i, (tpu_map, want) in enumerate(cases):
      with self.subTest(case=i):
        c = client.Client(
            tpu=tpu, service=self.mock_service_client(tpu_map=tpu_map))
        self.assertEqual(want, c.recoverable())

  def testEnvironmentDiscoveryUrl(self):
    os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
    self.assertEqual('https://{api}.internal/{apiVersion}',
                     client._environment_discovery_url())

  def testEnvironmentVarToNetworkEndpointsSingleIp(self):
    self.assertEqual(
        [{'ipAddress': '1.2.3.4', 'port': '1234'}],
        list(client._environment_var_to_network_endpoints('1.2.3.4:1234')))

  def testEnvironmentVarToNetworkEndpointsSingleGrpcAddress(self):
    self.assertEqual(
        [{'ipAddress': '1.2.3.4', 'port': '2000'}],
        list(client._environment_var_to_network_endpoints(
            'grpc://1.2.3.4:2000')))

  def testEnvironmentVarToNetworkEndpointsMultipleIps(self):
    self.assertEqual(
        [{'ipAddress': '1.2.3.4', 'port': '2000'},
         {'ipAddress': '5.6.7.8', 'port': '1234'}],
        list(client._environment_var_to_network_endpoints(
            '1.2.3.4:2000,5.6.7.8:1234')))

  def testEnvironmentVarToNetworkEndpointsMultipleGrpcAddresses(self):
    self.assertEqual(
        [{'ipAddress': '1.2.3.4', 'port': '2000'},
         {'ipAddress': '5.6.7.8', 'port': '1234'}],
        list(client._environment_var_to_network_endpoints(
            'grpc://1.2.3.4:2000,grpc://5.6.7.8:1234')))

  def testEnvironmentVarToNetworkEndpointsMissingPortAndMixed(self):
    self.assertEqual(
        [{'ipAddress': '1.2.3.4', 'port': '2000'},
         {'ipAddress': '5.6.7.8', 'port': '8470'}],
        list(client._environment_var_to_network_endpoints(
            '1.2.3.4:2000,grpc://5.6.7.8')))

  def testInitializeNoArguments(self):
    with self.assertRaisesRegex(
        ValueError, 'Please provide a TPU Name to connect to.'):
      client.Client()

  def testInitializeMultiElementTpuArray(self):
    with self.assertRaisesRegex(
        NotImplementedError,
        'Using multiple TPUs in a single session is not yet implemented'):
      client.Client(tpu=['multiple', 'elements'])

  def assertClientContains(self, c):
    """Asserts `c` is the API-backed client for node tpu_name at 10.1.2.3."""
    self.assertEqual('tpu_name', c._tpu)
    self.assertEqual(True, c._use_api)
    self.assertIsNone(c._credentials)
    self.assertEqual('test-project', c._project)
    self.assertEqual('us-central1-c', c._zone)
    self.assertIsNone(c._discovery_url)
    self.assertEqual([{
        'ipAddress': '10.1.2.3',
        'port': '8470'
    }], c.network_endpoints())

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testNetworkEndpointsNotReadyWithApi(self):
    # The node reports no 'state' at all.
    tpu_map = self._tpu_map()
    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    with self.assertRaisesRegex(RuntimeError,
                                'TPU .* is not yet ready; state: "None"'):
      c.network_endpoints()

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testInitializeNoArgumentsWithEnvironmentVariable(self):
    os.environ['TPU_NAME'] = 'tpu_name'
    tpu_map = self._tpu_map(state='READY', health='HEALTHY')
    c = client.Client(service=self.mock_service_client(tpu_map=tpu_map))
    self.assertClientContains(c)

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testInitializeNoArgumentsWithTPUEnvironmentVariableTPUConfig(self):
    os.environ['TPU_CONFIG'] = json.dumps({
        'project': 'test-project',
        'zone': 'us-central1-c',
        'tpu_node_name': 'tpu_name',
    })
    tpu_map = self._tpu_map(state='READY', health='HEALTHY')
    c = client.Client(service=self.mock_service_client(tpu_map=tpu_map))
    self.assertClientContains(c)

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testInitializeTpuName(self):
    tpu_map = self._tpu_map(state='READY', health='HEALTHY')
    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertClientContains(c)

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testInitializeIpAddress(self):
    c = client.Client(tpu='grpc://1.2.3.4:8470')
    self.assertEqual('grpc://1.2.3.4:8470', c._tpu)
    self.assertEqual(False, c._use_api)
    self.assertIsNone(c._service)
    self.assertIsNone(c._credentials)
    self.assertIsNone(c._project)
    self.assertIsNone(c._zone)
    self.assertIsNone(c._discovery_url)
    self.assertEqual([{
        'ipAddress': '1.2.3.4',
        'port': '8470'
    }], c.network_endpoints())

  def testInitializeWithoutMetadata(self):
    c = client.Client(tpu='tpu_name', project='project', zone='zone')
    self.assertEqual('tpu_name', c._tpu)
    self.assertEqual(True, c._use_api)
    self.assertIsNone(c._service)
    self.assertIsNone(c._credentials)
    self.assertEqual('project', c._project)
    self.assertEqual('zone', c._zone)
    self.assertIsNone(c._discovery_url)

  def testRecoverableNoApiAccess(self):
    c = client.Client(tpu='grpc://1.2.3.4:8470')
    self.assertEqual(True, c.recoverable())

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testRecoverableNoState(self):
    self._assert_recoverable([(self._tpu_map(), True)])

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testRecoverableReady(self):
    self._assert_recoverable([(self._tpu_map(state='READY'), True)])

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testRecoverablePreempted(self):
    self._assert_recoverable([(self._tpu_map(state='PREEMPTED'), False)])

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  @mock.patch.object(client, '_utcnow', mock_utcnow)
  def testRecoverableOOM(self):
    self._assert_recoverable(
        self._oom_cases('OUT_OF_MEMORY', self._RUNTIME_OOM_DETAILS))

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  @mock.patch.object(client, '_utcnow', mock_utcnow)
  def testRecoverableOOMDisabled(self):
    # Restored by a cleanup even if the assertion below fails.
    self._override_flag('runtime_oom_exit', False)
    tpu_map = self._ready_map([
        self._symptom('00:29:30', 'OUT_OF_MEMORY', self._RUNTIME_OOM_DETAILS)
    ])
    self._assert_recoverable([(tpu_map, True)])

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  @mock.patch.object(client, '_utcnow', mock_utcnow)
  def testRecoverableOOMNoAPI(self):
    tpu_map = self._ready_map([
        self._symptom('00:29:30', 'OUT_OF_MEMORY', self._RUNTIME_OOM_DETAILS)
    ])
    self._assert_recoverable([(tpu_map, True)], tpu='grpc://1.2.3.4:8470')

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  @mock.patch.object(client, '_utcnow', mock_utcnow)
  def testRecoverableHBMOOM(self):
    self._assert_recoverable(
        self._oom_cases('HBM_OUT_OF_MEMORY', self._HBM_OOM_DETAILS))

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  @mock.patch.object(client, '_utcnow', mock_utcnow)
  def testRecoverableHBMOOMDisabled(self):
    # Restored by a cleanup even if the assertion below fails.
    self._override_flag('hbm_oom_exit', False)
    tpu_map = self._ready_map([
        self._symptom('00:29:30', 'HBM_OUT_OF_MEMORY', self._HBM_OOM_DETAILS)
    ])
    self._assert_recoverable([(tpu_map, True)])

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  @mock.patch.object(client, '_utcnow', mock_utcnow)
  def testRecoverableHBMOOMNoAPI(self):
    tpu_map = self._ready_map([
        self._symptom('00:29:30', 'HBM_OUT_OF_MEMORY', self._HBM_OOM_DETAILS)
    ])
    self._assert_recoverable([(tpu_map, True)], tpu='grpc://1.2.3.4:8470')

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testHealthApi(self):
    tpu_map = self._tpu_map(state='PREEMPTED', health='HEALTHY',
                            acceleratorType='v3-8',
                            tensorflowVersion='nightly')
    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertEqual('HEALTHY', c.health())

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testRuntimeVersionApi(self):
    tpu_map = self._tpu_map(state='PREEMPTED', health='HEALTHY',
                            acceleratorType='v3-8',
                            tensorflowVersion='nightly')
    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertEqual('nightly', c.runtime_version())

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testAcceleratorTypeApi(self):
    tpu_map = self._tpu_map(state='PREEMPTED', health='HEALTHY',
                            acceleratorType='v3-8',
                            tensorflowVersion='nightly')
    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))
    self.assertEqual('v3-8', c.accelerator_type())

  def testHandlesByteStrings(self):
    self.assertEqual(
        client.Client(
            tpu='tpu_name', zone='zone', project='project')._full_name(),
        client.Client(
            tpu=b'tpu_name', zone=b'zone', project=b'project')._full_name(),
    )

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testWaitForHealthy(self):
    sleep_mock = self._install_fake_clock()

    # One health sample per virtual second; HEALTHY from second 60 onwards.
    health_timeseries = (['UNHEALTHY_MAINTENANCE'] * 30 + ['TIMEOUT'] * 10 +
                         [None] * 20 + ['HEALTHY'] * 30)
    tpu_map = self._tpu_map(state='READY', health=health_timeseries)
    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))

    # Doesn't throw RuntimeError as TPU becomes HEALTHY before timeout.
    timeout = 80
    interval = 5
    return_time = 60
    c.wait_for_healthy(timeout_s=timeout, interval=interval)
    self.assertEqual(return_time, time.time())
    self.assertEqual(return_time // interval, sleep_mock.call_count)

  @mock.patch.object(client, '_request_compute_metadata',
                     mock_request_compute_metadata)
  def testWaitForHealthyRaisesError(self):
    self._install_fake_clock()

    # Health never becomes HEALTHY, so waiting must exceed the timeout.
    health_timeseries = ['UNHEALTHY_MAINTENANCE'] * 50 + ['TIMEOUT'] * 50
    tpu_map = self._tpu_map(state='READY', health=health_timeseries)
    c = client.Client(
        tpu='tpu_name', service=self.mock_service_client(tpu_map=tpu_map))

    # Throws RuntimeError because the TPU is still unhealthy at the timeout.
    with self.assertRaisesRegex(
        RuntimeError, 'Timed out waiting for TPU .* to become healthy'):
      c.wait_for_healthy(timeout_s=80, interval=5)

  def baseConfigureTpuVersion(self):
    """Returns an API-backed client whose node lists two network endpoints."""
    tpu_map = {
        self._NODE_NAME: {
            'state': 'READY',
            'networkEndpoints': [
                {'ipAddress': '1.2.3.4'},
                {'ipAddress': '5.6.7.8'},
            ],
        }
    }
    return client.Client(
        tpu='tpu_name',
        project='test-project',
        zone='us-central1-c',
        service=self.mock_service_client(tpu_map=tpu_map))

  @mock.patch.object(urllib.request, 'urlopen')
  def testConfigureTpuVersion(self, urlopen):
    c = self.baseConfigureTpuVersion()
    c.configure_tpu_version('1.15')
    paths = [call[0][0].full_url for call in urlopen.call_args_list]
    self.assertCountEqual([
        'http://1.2.3.4:8475/requestversion/1.15?restartType=always',
        'http://5.6.7.8:8475/requestversion/1.15?restartType=always'
    ], paths)

  @mock.patch.object(urllib.request, 'urlopen')
  def testConfigureTpuVersionRestartIfneeded(self, urlopen):
    c = self.baseConfigureTpuVersion()
    c.configure_tpu_version('1.15', restart_type='ifNeeded')
    paths = [call[0][0].full_url for call in urlopen.call_args_list]
    self.assertCountEqual([
        'http://1.2.3.4:8475/requestversion/1.15?restartType=ifNeeded',
        'http://5.6.7.8:8475/requestversion/1.15?restartType=ifNeeded'
    ], paths)

  @mock.patch.object(urllib.request, 'urlopen')
  def testGetTpuVersion(self, urlopen):
    c = client.Client(tpu='grpc://1.2.3.4:8470')
    resp = mock.Mock()
    resp.read.side_effect = ['{}', '{"currentVersion": "someVersion"}']
    urlopen.return_value = resp
    self.assertIsNone(c.runtime_version(), 'Missing key should be handled.')
    self.assertEqual(
        'someVersion', c.runtime_version(), 'Should return configured version.')
    paths = [call[0][0].full_url for call in urlopen.call_args_list]
    self.assertCountEqual([
        'http://1.2.3.4:8475/requestversion',
        'http://1.2.3.4:8475/requestversion',
    ], paths)
848
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test entry point.
  test.main()
851