1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""TPU specific APIs to be used in conjunction with TPU Strategy.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import gc 22 23from tensorflow.core.protobuf import config_pb2 24from tensorflow.python.client import session as session_lib 25from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver 26from tensorflow.python.eager import context 27from tensorflow.python.eager import function 28from tensorflow.python.framework import device 29from tensorflow.python.framework import errors 30from tensorflow.python.framework import ops 31from tensorflow.python.platform import tf_logging as logging 32from tensorflow.python.tpu import topology 33from tensorflow.python.tpu import tpu 34from tensorflow.python.util import compat 35from tensorflow.python.util.tf_export import tf_export 36 37 38_INITIALIZED_TPU_SYSTEMS = {} 39_LOCAL_MASTERS = ("", "local") 40 41 42@tf_export("tpu.experimental.initialize_tpu_system") 43def initialize_tpu_system(cluster_resolver=None): 44 """Initialize the TPU devices. 45 46 Args: 47 cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, 48 which provides information about the TPU cluster. 49 Returns: 50 The tf.tpu.Topology object for the topology of the TPU cluster. If called 51 inside tf.function, it returns the serialized topology object instead. 52 53 Raises: 54 RuntimeError: If running inside a tf.function. 55 NotFoundError: If no TPU devices found in eager mode. 56 """ 57 58 # Deallocate all TPU buffers by clearing out eager context caches and 59 # triggering garbage collection to avoid keeping invalid tpu buffer around 60 # after reinitialized tpu system. 61 logging.info("Deallocate tpu buffers before initializing tpu system.") 62 context.context()._clear_caches() # pylint: disable=protected-access 63 context.context().clear_kernel_cache() 64 gc.collect() 65 66 job = None 67 if cluster_resolver is None: 68 # If no cluster resolver is specified, and running eagerly, execute the init 69 # ops in the current device scope. 70 if context.executing_eagerly(): 71 curr_device = device.DeviceSpec.from_string(context.context().device_name) 72 if curr_device.job is not None: 73 job = "{}/replica:0/task:0".format(curr_device.job) 74 75 cluster_resolver = TPUClusterResolver("") 76 assert isinstance(cluster_resolver, TPUClusterResolver) 77 78 tpu_name = compat.as_text(cluster_resolver._tpu) # pylint: disable=protected-access 79 if tpu_name in _INITIALIZED_TPU_SYSTEMS: 80 logging.warning( 81 "TPU system %s has already been initialized. " 82 "Reinitializing the TPU can cause previously created " 83 "variables on TPU to be lost.", tpu_name) 84 85 logging.info("Initializing the TPU system: %s", tpu_name) 86 87 # This function looks as it is for the following non-intuitive reasons. 88 # tpu.initialize_system creates a dummy op whose sole purpose is to trigger 89 # DistributedTPURewritePass. This pass actually adds real ops that 90 # initialize the TPU system. Thus, we can't simply run tpu.initialize_system 91 # eagerly. We need to wrap it in defun and trigger the rewrite passes on it. 92 if tpu_name not in _LOCAL_MASTERS: 93 # Explicitly place the tpu.initialize_system in the first worker to 94 # avoid the output node match multiple devices error. 95 job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name()) 96 97 if context.executing_eagerly(): 98 @function.defun 99 def _tpu_init_fn(): 100 # In TF1, we usually close chips when compilation fails to clear the data 101 # in infeed. In TF2, we don't need to do this because infeed is no longer 102 # used, so user can recover from TPU compilation failures more smoothly. 103 return tpu.initialize_system( 104 job=job, compilation_failure_closes_chips=False) 105 106 # The TPU_SYSTEM device must match the device used in tpu.initialize_system 107 # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM 108 # devices available. 109 try: 110 with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access 111 output = _tpu_init_fn() 112 context.async_wait() 113 except errors.InvalidArgumentError as e: 114 raise errors.NotFoundError( 115 None, None, 116 "TPUs not found in the cluster. Failed in initialization: " 117 + str(e)) 118 119 # Clear out the eager context caches since the memory is invalid now. 120 context.context()._initialize_logical_devices() # pylint: disable=protected-access 121 122 serialized_topology = output.numpy() 123 elif not ops.executing_eagerly_outside_functions(): 124 master = cluster_resolver.master() 125 cluster_spec = cluster_resolver.cluster_spec() 126 127 session_config = config_pb2.ConfigProto(allow_soft_placement=True) 128 if cluster_spec: 129 session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) 130 131 with ops.Graph().as_default(): 132 with session_lib.Session(config=session_config, target=master) as sess: 133 serialized_topology = sess.run(tpu.initialize_system()) 134 else: 135 with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access 136 serialized_topology = tpu.initialize_system( 137 job=job, compilation_failure_closes_chips=False) 138 # If initialize_tpu_system is called inside tf.function, we only return 139 # the serialized topology object as the tf.tpu.Topology object has to be 140 # constructed in eager mode. 141 return serialized_topology 142 143 logging.info("Finished initializing TPU system.") 144 tpu_topology = topology.Topology(serialized=serialized_topology) 145 _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology 146 147 return tpu_topology 148 149 150def get_initialized_tpu_systems(): 151 """Returns all currently initialized tpu systems. 152 153 Returns: 154 A dictionary, with tpu name as the key and the tpu topology as the value. 155 """ 156 return _INITIALIZED_TPU_SYSTEMS.copy() 157 158 159@tf_export("tpu.experimental.shutdown_tpu_system") 160def shutdown_tpu_system(cluster_resolver=None): 161 """Shuts down the TPU devices. 162 163 This will clear all caches, even those that are maintained through sequential 164 calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation 165 cache. 166 167 Args: 168 cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver, 169 which provides information about the TPU cluster. 170 171 Raises: 172 RuntimeError: If no TPU devices found for eager execution or if run in a 173 tf.function. 174 """ 175 job = None 176 if cluster_resolver is None: 177 # If no cluster resolver is specified, and running eagerly, execute the init 178 # ops in the current device scope. 179 if context.executing_eagerly(): 180 curr_device = device.DeviceSpec.from_string(context.context().device_name) 181 if curr_device.job is not None: 182 job = "{}/replica:0/task:0".format(curr_device.job) 183 184 cluster_resolver = TPUClusterResolver("") 185 assert isinstance(cluster_resolver, TPUClusterResolver) 186 187 tpu_name = compat.as_text(cluster_resolver._tpu) # pylint: disable=protected-access 188 if tpu_name not in _INITIALIZED_TPU_SYSTEMS: 189 logging.warning("You are shutting down a TPU system %s that has not been " 190 "initialized." % tpu_name) 191 192 logging.info("Shutting down the TPU system: %s", tpu_name) 193 194 if context.executing_eagerly(): 195 # This function looks as it is for the following non-intuitive reasons. 196 # tpu.shutdown_system creates a dummy op whose sole purpose is to trigger 197 # DistributedTPURewritePass. This pass actually adds real ops that 198 # shutdown the TPU system. Thus, we can't simply run tpu.shutdown_system 199 # eagerly. We need to wrap it in defun and trigger the rewrite passes on it. 200 if tpu_name not in _LOCAL_MASTERS: 201 # Explicitly place the tpu.shutdown_system in the first worker to 202 # avoid the output node match multiple devices error. 203 job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name()) 204 205 @function.defun 206 def _tpu_shutdown_fn(): 207 tpu.shutdown_system(job=job) 208 209 # The TPU_SYSTEM device must match the device used in tpu.shutdown_system 210 # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM 211 # devices available. 212 with ops.device(tpu._tpu_system_device_name(job)): # pylint: disable=protected-access 213 _tpu_shutdown_fn() 214 215 # Clear out the eager context caches since the memory is invalid now. 216 logging.info("Clearing out eager caches") 217 context.context()._clear_caches() # pylint: disable=protected-access 218 context.context().clear_kernel_cache() 219 elif not ops.executing_eagerly_outside_functions(): 220 master = cluster_resolver.master() 221 cluster_spec = cluster_resolver.cluster_spec() 222 223 session_config = config_pb2.ConfigProto(allow_soft_placement=True) 224 if cluster_spec: 225 session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) 226 227 with ops.Graph().as_default(): 228 with session_lib.Session(config=session_config, target=master) as sess: 229 sess.run(tpu.shutdown_system()) 230 else: 231 raise RuntimeError( 232 "initialize_tpu_system is not supported within " 233 "tf.functions. You should call initialize_tpu_system outside of your tf.function. " 234 ) 235 236 logging.info("Finished shutting down TPU system.") 237 if tpu_name in _INITIALIZED_TPU_SYSTEMS: 238 del _INITIALIZED_TPU_SYSTEMS[tpu_name] 239