# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU specific APIs to be used in conjunction with TPU Strategy."""

import gc

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import topology
from tensorflow.python.tpu import tpu
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


_INITIALIZED_TPU_SYSTEMS = {}
_LOCAL_MASTERS = ("", "local")


@tf_export("tpu.experimental.initialize_tpu_system")
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
      which provides information about the TPU cluster.

  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster. If called
    inside tf.function, it returns the serialized topology object instead.

  Raises:
    NotFoundError: If no TPU devices are found in eager mode.
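
  Example (a minimal sketch, illustrative only; it assumes a reachable Cloud
  TPU and that TensorFlow has been imported as `tf`):

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
    tf.config.experimental_connect_to_cluster(resolver)
    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)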
  """

  # Deallocate all TPU buffers by clearing out eager context caches and
  # triggering garbage collection, to avoid keeping stale TPU buffers around
  # after the TPU system is reinitialized.
  logging.info("Deallocate tpu buffers before initializing tpu system.")
  context.context()._clear_caches()  # pylint: disable=protected-access
  context.context().clear_kernel_cache()
  gc.collect()

  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified and we are running eagerly, execute
    # the init ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "TPU system %s has already been initialized. "
        "Reinitializing the TPU can cause previously created "
        "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system: %s", tpu_name)

  # This function is structured the way it is for the following non-intuitive
  # reason: tpu.initialize_system creates a dummy op whose sole purpose is to
  # trigger DistributedTPURewritePass. That pass adds the real ops that
  # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
  # eagerly; we need to wrap it in a defun and trigger the rewrite passes on
  # it.
  if tpu_name not in _LOCAL_MASTERS:
    # Explicitly place tpu.initialize_system on the first worker to avoid the
    # "output node matches multiple devices" error.
    job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

  if context.executing_eagerly():
    @function.defun
    def _tpu_init_fn():
      # In TF1, we usually close chips when compilation fails to clear the data
      # in infeed. In TF2, we don't need to do this because infeed is no longer
      # used, so users can recover from TPU compilation failures more smoothly.
      # The same applies to the cancellation of a TPU execution.
      return tpu.initialize_system(
          job=job,
          compilation_failure_closes_chips=False,
          tpu_cancellation_closes_chips=False)

    # The TPU_SYSTEM device must match the device used in tpu.initialize_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    try:
      with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
        output = _tpu_init_fn()
      context.async_wait()
    except errors.InvalidArgumentError as e:
      raise errors.NotFoundError(
          None, None,
          "TPUs not found in the cluster. Failed in initialization: "
          + str(e))

    # Re-initialize the logical devices so the freshly initialized TPU devices
    # are registered with the eager context.
    context.context()._initialize_logical_devices()  # pylint: disable=protected-access

    serialized_topology = output.numpy()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())
  else:
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      serialized_topology = tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)
      # If initialize_tpu_system is called inside tf.function, we only return
      # the serialized topology object, as the tf.tpu.Topology object has to
      # be constructed in eager mode.
      return serialized_topology

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  cluster_resolver.set_tpu_topology(serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology


def get_initialized_tpu_systems():
  """Returns all currently initialized TPU systems.

  Returns:
    A dictionary with the TPU name as the key and the TPU topology as the
    value.
  """
  return _INITIALIZED_TPU_SYSTEMS.copy()


@tf_export("tpu.experimental.shutdown_tpu_system")
def shutdown_tpu_system(cluster_resolver=None):
  """Shuts down the TPU devices.

  This will clear all caches, even those that are maintained through sequential
  calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
  cache.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
      which provides information about the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices are found for eager execution or if this
      is run inside a tf.function.
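
  Example (a minimal sketch, illustrative only; it assumes a reachable Cloud
  TPU, that the TPU system was initialized earlier in the same process, and
  that TensorFlow has been imported as `tf`):

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    # ... run TPU computations ...
    tf.tpu.experimental.shutdown_tpu_system(resolver)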
  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified and we are running eagerly, execute
    # the shutdown ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "You are shutting down a TPU system %s that has not been initialized.",
        tpu_name)

  logging.info("Shutting down the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function is structured the way it is for the following non-intuitive
    # reason: tpu.shutdown_system creates a dummy op whose sole purpose is to
    # trigger DistributedTPURewritePass. That pass adds the real ops that shut
    # down the TPU system. Thus, we can't simply run tpu.shutdown_system
    # eagerly; we need to wrap it in a defun and trigger the rewrite passes on
    # it.
    if tpu_name not in _LOCAL_MASTERS:
      # Explicitly place tpu.shutdown_system on the first worker to avoid the
      # "output node matches multiple devices" error.
      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    @function.defun
    def _tpu_shutdown_fn():
      tpu.shutdown_system(job=job)

    # The TPU_SYSTEM device must match the device used in tpu.shutdown_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      _tpu_shutdown_fn()

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access
    context.context().clear_kernel_cache()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        sess.run(tpu.shutdown_system())
  else:
    raise RuntimeError(
        "shutdown_tpu_system is not supported within tf.functions. You should "
        "call shutdown_tpu_system outside of your tf.function.")

  logging.info("Finished shutting down TPU system.")
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    del _INITIALIZED_TPU_SYSTEMS[tpu_name]
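

# A minimal end-to-end sketch of the lifecycle implemented above (illustrative
# only; it assumes a reachable TPU, that only this TPU system is initialized in
# the process, and that TensorFlow has been imported as `tf`):
#
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
#   tf.tpu.experimental.initialize_tpu_system(resolver)
#   get_initialized_tpu_systems()   # now maps this TPU's name to its topology
#   tf.tpu.experimental.shutdown_tpu_system(resolver)
#   get_initialized_tpu_systems()   # the entry for this TPU has been removed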