# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU specific APIs to be used in conjunction with TPU Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import device
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import topology
from tensorflow.python.tpu import tpu
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


_INITIALIZED_TPU_SYSTEMS = {}
_LOCAL_MASTERS = ("", "local")


@tf_export("tpu.experimental.initialize_tpu_system")
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
      which provides information about the TPU cluster.

  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices are found for eager execution, or if this
      is run inside a tf.function.
  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified, and running eagerly, execute the
    # init ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(
          context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "TPU system %s has already been initialized. "
        "Reinitializing the TPU can cause previously created "
        "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function is structured this way for a non-intuitive reason:
    # tpu.initialize_system creates a dummy op whose sole purpose is to
    # trigger DistributedTPURewritePass. That pass adds the real ops that
    # initialize the TPU system. Thus, we can't simply run
    # tpu.initialize_system eagerly; we need to wrap it in a defun and
    # trigger the rewrite passes on it.
    if tpu_name not in _LOCAL_MASTERS:
      # Explicitly place tpu.initialize_system on the first worker to avoid
      # the "output node matches multiple devices" error.
      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    @function.defun
    def _tpu_init_fn():
      # In TF1, we usually close chips when compilation fails to clear the
      # data in infeed. In TF2, we don't need to do this because infeed is no
      # longer used, so users can recover from TPU compilation failures more
      # smoothly.
      return tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)

    # The TPU_SYSTEM device must match the device used in
    # tpu.initialize_system exactly, otherwise you can get errors if there
    # are multiple TPU_SYSTEM devices available.
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      output = _tpu_init_fn()

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access

    serialized_topology = output.numpy()

    # TODO(b/134094971): Remove this when lazy tensor copy in multi-device
    # function has been implemented.
    context.context().mirroring_policy = context.MIRRORING_ALL
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())
  else:
    raise RuntimeError("initialize_tpu_system is not supported within "
                       "tf.functions.")

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology


@tf_export("tpu.experimental.shutdown_tpu_system")
def shutdown_tpu_system(cluster_resolver=None):
  """Shuts down the TPU devices.

  This will clear all caches, even those that are maintained through
  sequential calls to tf.tpu.experimental.initialize_tpu_system, such as the
  compilation cache.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
      which provides information about the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices are found for eager execution, or if this
      is run inside a tf.function.
  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified, and running eagerly, execute the
    # shutdown ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(
          context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
    logging.warning("You are shutting down a TPU system %s that has not been "
                    "initialized.", tpu_name)

  logging.info("Shutting down the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function is structured this way for a non-intuitive reason:
    # tpu.shutdown_system creates a dummy op whose sole purpose is to trigger
    # DistributedTPURewritePass. That pass adds the real ops that shut down
    # the TPU system. Thus, we can't simply run tpu.shutdown_system eagerly;
    # we need to wrap it in a defun and trigger the rewrite passes on it.
    if tpu_name not in _LOCAL_MASTERS:
      # Explicitly place tpu.shutdown_system on the first worker to avoid the
      # "output node matches multiple devices" error.
      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    @function.defun
    def _tpu_shutdown_fn():
      tpu.shutdown_system(job=job)

    # The TPU_SYSTEM device must match the device used in tpu.shutdown_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      _tpu_shutdown_fn()

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        sess.run(tpu.shutdown_system())
  else:
    raise RuntimeError("shutdown_tpu_system is not supported within "
                       "tf.functions.")

  logging.info("Finished shutting down TPU system.")
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    del _INITIALIZED_TPU_SYSTEMS[tpu_name]
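

# Example usage (a minimal sketch, not part of the library and kept as
# comments so importing this module stays side-effect free). It assumes a TPU
# is reachable and uses only the public wrappers around the functions above:
# tf.distribute.cluster_resolver.TPUClusterResolver,
# tf.tpu.experimental.initialize_tpu_system,
# tf.tpu.experimental.shutdown_tpu_system, and
# tf.distribute.experimental.TPUStrategy.
#
#   import tensorflow as tf
#
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
#   topology = tf.tpu.experimental.initialize_tpu_system(resolver)
#   print("TPU mesh shape:", topology.mesh_shape)
#
#   strategy = tf.distribute.experimental.TPUStrategy(resolver)
#   # ... build and train a model under strategy.scope() ...
#
#   # Shutting down also clears caches kept across initialize calls, such as
#   # the compilation cache.
#   tf.tpu.experimental.shutdown_tpu_system(resolver)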