# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""TPU specific APIs to be used in conjunction with TPU Strategy."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import gc
22
23from tensorflow.core.protobuf import config_pb2
24from tensorflow.python.client import session as session_lib
25from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver
26from tensorflow.python.eager import context
27from tensorflow.python.eager import function
28from tensorflow.python.framework import device
29from tensorflow.python.framework import errors
30from tensorflow.python.framework import ops
31from tensorflow.python.platform import tf_logging as logging
32from tensorflow.python.tpu import topology
33from tensorflow.python.tpu import tpu
34from tensorflow.python.util import compat
35from tensorflow.python.util.tf_export import tf_export
36
37
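# Maps the canonical TPU name to the tf.tpu.Topology recorded when that TPU
# system was initialized in this process. Consulted to warn about
# re-initialization and served by get_initialized_tpu_systems().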
_INITIALIZED_TPU_SYSTEMS = {}
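# Master values that denote a local (single-host) TPU target, for which no
# explicit job placement is needed.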
_LOCAL_MASTERS = ("", "local")


@tf_export("tpu.experimental.initialize_tpu_system")
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.

  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster. If called
    inside tf.function, it returns the serialized topology object instead.

  Raises:
    RuntimeError: If running inside a tf.function.
    NotFoundError: If no TPU devices are found in eager mode.
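
  Example (a minimal sketch; "my-tpu" is a hypothetical Cloud TPU name, and
  the exact resolver/connection setup depends on your environment):

  ```python
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")
  tf.config.experimental_connect_to_cluster(resolver)
  # Initialize the TPU system once, before creating a TPUStrategy.
  topology = tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.TPUStrategy(resolver)
  ```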
56  """

  # Deallocate all TPU buffers by clearing out eager context caches and
  # triggering garbage collection. This avoids keeping invalid TPU buffers
  # around after the TPU system is reinitialized.
  logging.info("Deallocating TPU buffers before initializing the TPU system.")
  context.context()._clear_caches()  # pylint: disable=protected-access
  context.context().clear_kernel_cache()
  gc.collect()

  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified and we are running eagerly, execute
    # the init ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "TPU system %s has already been initialized. "
        "Reinitializing the TPU can cause previously created "
        "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system: %s", tpu_name)

  # This function is structured this way for the following non-intuitive
  # reason: tpu.initialize_system creates a dummy op whose sole purpose is to
  # trigger DistributedTPURewritePass. That pass adds the real ops that
  # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
  # eagerly; we need to wrap it in a defun and trigger the rewrite passes on
  # it.
  if tpu_name not in _LOCAL_MASTERS:
    # Explicitly place tpu.initialize_system on the first worker to avoid an
    # "output node matches multiple devices" error.
    job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

  if context.executing_eagerly():
    @function.defun
    def _tpu_init_fn():
      # In TF1, we usually close chips when compilation fails to clear the
      # data in infeed. In TF2, we don't need to do this because infeed is no
      # longer used, so users can recover from TPU compilation failures more
      # smoothly.
      return tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)

    # The TPU_SYSTEM device must match the device used in tpu.initialize_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    try:
      with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
        output = _tpu_init_fn()
      # Make sure the (possibly asynchronous) initialization has completed.
      context.async_wait()
    except errors.InvalidArgumentError as e:
      raise errors.NotFoundError(
          None, None,
          "TPUs not found in the cluster. Failed in initialization: "
          + str(e))

    # Re-initialize the logical devices, since the old device handles are
    # invalid now.
    context.context()._initialize_logical_devices()  # pylint: disable=protected-access

    serialized_topology = output.numpy()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())
  else:
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      serialized_topology = tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)
      # If initialize_tpu_system is called inside tf.function, we only return
      # the serialized topology object as the tf.tpu.Topology object has to be
      # constructed in eager mode.
      return serialized_topology

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology


def get_initialized_tpu_systems():
  """Returns all currently initialized TPU systems.

  Returns:
     A dictionary, with the TPU name as the key and the TPU topology as the
     value.
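
  Example (a sketch; assumes initialize_tpu_system has already been called in
  this process):

  ```python
  for name, tpu_topology in get_initialized_tpu_systems().items():
    # Topology exposes system shape information, e.g. the number of TPU tasks.
    print(name, tpu_topology.num_tasks)
  ```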
155  """
156  return _INITIALIZED_TPU_SYSTEMS.copy()
157
158
@tf_export("tpu.experimental.shutdown_tpu_system")
def shutdown_tpu_system(cluster_resolver=None):
  """Shuts down the TPU devices.

  This will clear all caches, even those that are maintained through sequential
  calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
  cache.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices are found for eager execution, or if run
        inside a tf.function.
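
  Example (a minimal sketch; "my-tpu" is a hypothetical Cloud TPU name):

  ```python
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")
  tf.tpu.experimental.initialize_tpu_system(resolver)
  # ... run TPU work, then release the devices:
  tf.tpu.experimental.shutdown_tpu_system(resolver)
  ```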
174  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified and we are running eagerly, execute
    # the shutdown ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "You are shutting down TPU system %s that has not been initialized.",
        tpu_name)

  logging.info("Shutting down the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function is structured this way for the following non-intuitive
    # reason: tpu.shutdown_system creates a dummy op whose sole purpose is to
    # trigger DistributedTPURewritePass. That pass adds the real ops that shut
    # down the TPU system. Thus, we can't simply run tpu.shutdown_system
    # eagerly; we need to wrap it in a defun and trigger the rewrite passes on
    # it.
    if tpu_name not in _LOCAL_MASTERS:
      # Explicitly place tpu.shutdown_system on the first worker to avoid an
      # "output node matches multiple devices" error.
      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    @function.defun
    def _tpu_shutdown_fn():
      tpu.shutdown_system(job=job)

    # The TPU_SYSTEM device must match the device used in tpu.shutdown_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      _tpu_shutdown_fn()

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access
    context.context().clear_kernel_cache()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        sess.run(tpu.shutdown_system())
  else:
    raise RuntimeError(
        "shutdown_tpu_system is not supported within tf.functions. You "
        "should call shutdown_tpu_system outside of your tf.function.")

  logging.info("Finished shutting down TPU system.")
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    del _INITIALIZED_TPU_SYSTEMS[tpu_name]