# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU specific APIs to be used in conjunction with TPU Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import topology
from tensorflow.python.tpu import tpu
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


_INITIALIZED_TPU_SYSTEMS = {}
_LOCAL_MASTERS = ("", "local")


@tf_export("tpu.experimental.initialize_tpu_system")
def initialize_tpu_system(cluster_resolver=None):
  """Initializes the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.

  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster. If called
    inside tf.function, it returns the serialized topology object instead.

  Raises:
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices are found in eager mode.
  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified and we are running eagerly, execute
    # the init ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(
          context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "TPU system %s has already been initialized. "
        "Reinitializing the TPU can cause previously created "
        "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system: %s", tpu_name)

  # This function is structured the way it is for the following non-intuitive
  # reason: tpu.initialize_system creates a dummy op whose sole purpose is to
  # trigger DistributedTPURewritePass. That pass adds the real ops that
  # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
  # eagerly; we need to wrap it in a defun and trigger the rewrite passes on
  # it.
  if tpu_name not in _LOCAL_MASTERS:
    # Explicitly place the tpu.initialize_system op on the first worker to
    # avoid the "output node matches multiple devices" error.
    job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

  if context.executing_eagerly():
    @function.defun
    def _tpu_init_fn():
      # In TF1 we usually close chips when compilation fails to clear the data
      # in infeed. In TF2 we don't need to do this because infeed is no longer
      # used, so users can recover from TPU compilation failures more smoothly.
      return tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)

    # The TPU_SYSTEM device must match the device used in tpu.initialize_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    try:
      with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
        output = _tpu_init_fn()
      context.async_wait()
    except errors.InvalidArgumentError as e:
      raise errors.NotFoundError(
          None, None,
          "TPUs not found in the cluster. Failed in initialization: "
          + str(e))

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access
    context.context()._initialize_logical_devices()  # pylint: disable=protected-access
    context.context().clear_kernel_cache()

    serialized_topology = output.numpy()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())
  else:
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      serialized_topology = tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)
      # If initialize_tpu_system is called inside tf.function, we only return
      # the serialized topology object, as the tf.tpu.Topology object has to be
      # constructed in eager mode.
      return serialized_topology

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology


@tf_export("tpu.experimental.shutdown_tpu_system")
def shutdown_tpu_system(cluster_resolver=None):
  """Shuts down the TPU devices.

  This will clear all caches, even those that are maintained through sequential
  calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
  cache.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices are found for eager execution or if run
        inside a tf.function.
  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified and we are running eagerly, execute
    # the shutdown ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(
          context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "You are shutting down a TPU system %s that has not been initialized.",
        tpu_name)

  logging.info("Shutting down the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function is structured the way it is for the following non-intuitive
    # reason: tpu.shutdown_system creates a dummy op whose sole purpose is to
    # trigger DistributedTPURewritePass. That pass adds the real ops that
    # shut down the TPU system. Thus, we can't simply run tpu.shutdown_system
    # eagerly; we need to wrap it in a defun and trigger the rewrite passes on
    # it.
    if tpu_name not in _LOCAL_MASTERS:
      # Explicitly place the tpu.shutdown_system op on the first worker to
      # avoid the "output node matches multiple devices" error.
      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    @function.defun
    def _tpu_shutdown_fn():
      tpu.shutdown_system(job=job)

    # The TPU_SYSTEM device must match the device used in tpu.shutdown_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      _tpu_shutdown_fn()

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access
    context.context().clear_kernel_cache()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        sess.run(tpu.shutdown_system())
  else:
    raise RuntimeError(
        "shutdown_tpu_system is not supported within tf.functions.")

  logging.info("Finished shutting down TPU system.")
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    del _INITIALIZED_TPU_SYSTEMS[tpu_name]
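
# A minimal usage sketch for the two public APIs above (not executed as part
# of this module). It assumes TF2 eager execution and a reachable TPU worker;
# the empty string passed to TPUClusterResolver below is a placeholder for
# whatever TPU name or address applies in a given environment.
#
#   import tensorflow as tf
#
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
#   tf.config.experimental_connect_to_cluster(resolver)
#   tpu_topology = tf.tpu.experimental.initialize_tpu_system(resolver)
#   print("TPU mesh shape:", tpu_topology.mesh_shape)
#
#   # ... run TPU computations, e.g. under a TPUStrategy ...
#
#   tf.tpu.experimental.shutdown_tpu_system(resolver)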