# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU specific APIs to be used in conjunction with TPU Strategy."""

import gc

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import topology
from tensorflow.python.tpu import tpu
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


_INITIALIZED_TPU_SYSTEMS = {}
_LOCAL_MASTERS = ("", "local")


@tf_export("tpu.experimental.initialize_tpu_system")
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.

  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster. If called
    inside tf.function, it returns the serialized topology object instead.

  Raises:
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices found in eager mode.
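
  Example: a minimal usage sketch, not taken from this module; it assumes a
  TPU worker is reachable through the default
  `tf.distribute.cluster_resolver.TPUClusterResolver` and that
  `tf.distribute.TPUStrategy` is the intended consumer of the initialized
  system:

  ```python
  import tensorflow as tf

  # tpu="" picks up a local or environment-configured TPU; adjust as needed.
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
  tf.config.experimental_connect_to_cluster(resolver)

  # Initialize the TPU devices and build a distribution strategy on top.
  topology = tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.TPUStrategy(resolver)
  ```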
  """

  # Deallocate all TPU buffers by clearing out eager context caches and
  # triggering garbage collection, to avoid keeping invalid TPU buffers around
  # after the TPU system is reinitialized.
  logging.info("Deallocating TPU buffers before initializing the TPU system.")
  context.context()._clear_caches()  # pylint: disable=protected-access
  context.context().clear_kernel_cache()
  gc.collect()

  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified and we are running eagerly, execute
    # the init ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "TPU system %s has already been initialized. "
        "Reinitializing the TPU can cause previously created "
        "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system: %s", tpu_name)

  # This function is structured the way it is for the following non-intuitive
  # reason: tpu.initialize_system creates a dummy op whose sole purpose is to
  # trigger DistributedTPURewritePass. That pass adds the real ops that
  # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
  # eagerly; we need to wrap it in a defun and trigger the rewrite passes on it.
  if tpu_name not in _LOCAL_MASTERS:
    # Explicitly place tpu.initialize_system on the first worker to avoid the
    # "output node matches multiple devices" error.
    job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

  if context.executing_eagerly():
    @function.defun
    def _tpu_init_fn():
      # In TF1, we usually close chips when compilation fails to clear the data
      # in infeed. In TF2, we don't need to do this because infeed is no longer
      # used, so users can recover from TPU compilation failures more smoothly.
      # The same applies to the cancellation of a TPU execution.
      return tpu.initialize_system(
          job=job,
          compilation_failure_closes_chips=False,
          tpu_cancellation_closes_chips=False)

    # The TPU_SYSTEM device must match the device used in tpu.initialize_system
    # exactly; otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    try:
      with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
        output = _tpu_init_fn()
      context.async_wait()
    except errors.InvalidArgumentError as e:
      raise errors.NotFoundError(
          None, None,
          "TPUs not found in the cluster. Failed in initialization: "
          + str(e))

    # Re-initialize the logical device list, since the previous TPU state is
    # invalid now.
    context.context()._initialize_logical_devices()  # pylint: disable=protected-access

    serialized_topology = output.numpy()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())
  else:
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      serialized_topology = tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)
      # If initialize_tpu_system is called inside tf.function, we only return
      # the serialized topology object, as the tf.tpu.Topology object has to be
      # constructed in eager mode.
      return serialized_topology

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  cluster_resolver.set_tpu_topology(serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology


def get_initialized_tpu_systems():
  """Returns all currently initialized TPU systems.

  Returns:
    A dictionary with the TPU name as the key and the TPU topology as the
    value.
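
  Example: an illustrative sketch of hypothetical caller code inside this
  module's namespace (the helper is not exported under `tf.*`):

  ```python
  # After one or more initialize_tpu_system(...) calls:
  for name, tpu_topology in get_initialized_tpu_systems().items():
    logging.info("TPU system %s has %d tasks.", name, tpu_topology.num_tasks)
  ```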
  """
  return _INITIALIZED_TPU_SYSTEMS.copy()


@tf_export("tpu.experimental.shutdown_tpu_system")
def shutdown_tpu_system(cluster_resolver=None):
  """Shuts down the TPU devices.

  This will clear all caches, even those that are maintained through sequential
  calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
  cache.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices found for eager execution or if run in a
        tf.function.
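
  Example: a minimal usage sketch, assuming the same default
  `tf.distribute.cluster_resolver.TPUClusterResolver` that was used to
  initialize the system:

  ```python
  import tensorflow as tf

  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)

  # ... run TPU work ...

  # Shut down to release TPU state; a later initialize_tpu_system call brings
  # the devices back up with fresh caches.
  tf.tpu.experimental.shutdown_tpu_system(resolver)
  ```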
  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified and we are running eagerly, execute
    # the shutdown ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "You are shutting down a TPU system %s that has not been initialized.",
        tpu_name)

  logging.info("Shutting down the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function is structured the way it is for the following non-intuitive
    # reason: tpu.shutdown_system creates a dummy op whose sole purpose is to
    # trigger DistributedTPURewritePass. That pass adds the real ops that shut
    # down the TPU system. Thus, we can't simply run tpu.shutdown_system
    # eagerly; we need to wrap it in a defun and trigger the rewrite passes on
    # it.
    if tpu_name not in _LOCAL_MASTERS:
      # Explicitly place tpu.shutdown_system on the first worker to avoid the
      # "output node matches multiple devices" error.
      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    @function.defun
    def _tpu_shutdown_fn():
      tpu.shutdown_system(job=job)

    # The TPU_SYSTEM device must match the device used in tpu.shutdown_system
    # exactly; otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      _tpu_shutdown_fn()

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access
    context.context().clear_kernel_cache()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        sess.run(tpu.shutdown_system())
  else:
    raise RuntimeError(
        "shutdown_tpu_system is not supported within a tf.function. You "
        "should call shutdown_tpu_system outside of your tf.function.")

  logging.info("Finished shutting down TPU system.")
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    del _INITIALIZED_TPU_SYSTEMS[tpu_name]