# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU specific APIs to be used in conjunction with TPU Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import functional as tpu_functional_ops
from tensorflow.python.tpu import topology
from tensorflow.python.tpu import tpu
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


def get_first_tpu_host_device(cluster_resolver):
  """Get the device spec for the first TPU host."""
  if context.executing_eagerly():
    tpu_devices = sorted(
        [x for x in context.list_devices() if "device:TPU:" in x])
    if not tpu_devices:
      raise RuntimeError("Could not find any TPU devices")
    spec = tf_device.DeviceSpec.from_string(tpu_devices[0])
    task_id = spec.task
  else:
    # The session master needs to be configured, and the coordinator is not
    # part of the cluster.
    task_id = 0
  if cluster_resolver.get_master() in ("", "local"):
    return "/replica:0/task:0/device:CPU:0"
  job_name = cluster_resolver.get_job_name() or "tpu_worker"
  return "/job:%s/task:%d/device:CPU:0" % (job_name, task_id)


@tf_export("tpu.experimental.initialize_tpu_system")
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
      which provides information about the TPU cluster.

  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster.
  """
  if cluster_resolver is None:
    cluster_resolver = TPUClusterResolver("")

  logging.info("Initializing the TPU system.")

  if context.executing_eagerly():
    # This function is structured the way it is for the following
    # non-intuitive reasons: tpu.initialize_system creates a dummy op whose
    # sole purpose is to trigger DistributedTPURewritePass, the pass that
    # actually adds the real ops that initialize the TPU system. Thus, we
    # can't simply run tpu.initialize_system eagerly; we need to wrap it in a
    # defun and trigger the rewrite passes on it. The easiest way to trigger
    # a rewrite is to run the function with TPUPartitionedCallOp.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    # We can't call _tpu_init_fn normally (because it contains just a dummy
    # op, see above), but we do need to define it so that it gets added to the
    # eager context and receives its assigned name.
    # pylint: disable=protected-access
    graph_func = _tpu_init_fn._get_concrete_function_internal()
    func_name = compat.as_str(graph_func._inference_function.name)
    # pylint: enable=protected-access

    with ops.device(get_first_tpu_host_device(cluster_resolver)):
      output = tpu_functional_ops.TPUPartitionedCall(
          args=[], device_ordinal=0, Tout=[dtypes.string], f=func_name)
    serialized_topology = output[0].numpy()
  else:
    master = cluster_resolver.master()
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())

  logging.info("Finished initializing TPU system.")
  return topology.Topology(serialized=serialized_topology)
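

# A minimal usage sketch: initialize the TPU system once at startup, before
# building anything that runs on the TPU. Passing "" to TPUClusterResolver
# assumes a locally reachable TPU (e.g. a Colab runtime or a Cloud TPU VM);
# pass the TPU's name or address for remote clusters. Guarded under __main__
# so importing this module stays side-effect free.
if __name__ == "__main__":
  resolver = TPUClusterResolver("")
  tpu_topology = initialize_tpu_system(resolver)
  logging.info("TPU system initialized: %d task(s), %d TPU device(s) each.",
               tpu_topology.num_tasks, tpu_topology.num_tpus_per_task)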