# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TPU specific APIs to be used in conjunction with TPU Strategy."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import device
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import topology
from tensorflow.python.tpu import tpu
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


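# `_INITIALIZED_TPU_SYSTEMS` caches the `Topology` returned for each TPU name,
# so we can warn on re-initialization and on shutting down an uninitialized
# system. `_LOCAL_MASTERS` lists master addresses that refer to the local TPU
# worker, for which the init/shutdown ops need not be pinned to a remote job.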
_INITIALIZED_TPU_SYSTEMS = {}
_LOCAL_MASTERS = ("", "local")


@tf_export("tpu.experimental.initialize_tpu_system")
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

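  Example usage (a minimal sketch; the TPU name and the
  tf.config.experimental_connect_to_cluster call below are illustrative and
  depend on your TensorFlow version and environment):

  ```python
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")
  tf.config.experimental_connect_to_cluster(resolver)
  topology = tf.tpu.experimental.initialize_tpu_system(resolver)
  print("TPU mesh shape:", topology.mesh_shape)
  ```
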
  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.

  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices are found for eager execution, or if called
        from inside a tf.function.
  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified and we are running eagerly, execute
    # the init ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "TPU system %s has already been initialized. "
        "Reinitializing the TPU can cause previously created "
        "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function looks the way it does for the following non-intuitive
    # reason: tpu.initialize_system creates a dummy op whose sole purpose is to
    # trigger DistributedTPURewritePass. That pass adds the real ops that
    # initialize the TPU system, so we can't simply run tpu.initialize_system
    # eagerly; we need to wrap it in a defun and trigger the rewrite passes on
    # it.
    if tpu_name not in _LOCAL_MASTERS:
      # Explicitly place tpu.initialize_system on the first worker to avoid the
      # "output node matches multiple devices" error.
      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    @function.defun
    def _tpu_init_fn():
      # In TF1, we usually close the TPU chips when compilation fails, in order
      # to clear any data left in the infeed. In TF2, infeed is no longer used,
      # so this isn't needed and users can recover from TPU compilation
      # failures more smoothly.
      return tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)

    # The TPU_SYSTEM device must match the device used in tpu.initialize_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      output = _tpu_init_fn()

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access

    serialized_topology = output.numpy()

    # TODO(b/134094971): Remove this when lazy tensor copy in multi-device
    # function has been implemented.
    context.context().mirroring_policy = context.MIRRORING_ALL
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())
  else:
    raise RuntimeError("initialize_tpu_system is not supported within "
                       "tf.functions.")

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology


@tf_export("tpu.experimental.shutdown_tpu_system")
def shutdown_tpu_system(cluster_resolver=None):
  """Shuts down the TPU devices.

  This will clear all caches, even those that are maintained through sequential
  calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
  cache.

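  Example usage (illustrative only; assumes the TPU system was previously
  initialized through the same resolver):

  ```python
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="my-tpu")
  tf.tpu.experimental.initialize_tpu_system(resolver)
  # ... run TPU computations ...
  tf.tpu.experimental.shutdown_tpu_system(resolver)
  ```
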
  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices are found for eager execution, or if called
        from inside a tf.function.
  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified and we are running eagerly, execute
    # the shutdown ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
    logging.warning("You are shutting down a TPU system %s that has not been "
                    "initialized.", tpu_name)

  logging.info("Shutting down the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function looks the way it does for the following non-intuitive
    # reason: tpu.shutdown_system creates a dummy op whose sole purpose is to
    # trigger DistributedTPURewritePass. That pass adds the real ops that shut
    # down the TPU system, so we can't simply run tpu.shutdown_system eagerly;
    # we need to wrap it in a defun and trigger the rewrite passes on it.
    if tpu_name not in _LOCAL_MASTERS:
      # Explicitly place tpu.shutdown_system on the first worker to avoid the
      # "output node matches multiple devices" error.
      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    @function.defun
    def _tpu_shutdown_fn():
      tpu.shutdown_system(job=job)

    # The TPU_SYSTEM device must match the device used in tpu.shutdown_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      _tpu_shutdown_fn()

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        sess.run(tpu.shutdown_system())
  else:
    raise RuntimeError("shutdown_tpu_system is not supported within "
                       "tf.functions.")

  logging.info("Finished shutting down TPU system.")
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    del _INITIALIZED_TPU_SYSTEMS[tpu_name]