# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""TPU specific APIs to be used in conjunction with TPU Strategy."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.protobuf import config_pb2
22from tensorflow.python.client import session as session_lib
23from tensorflow.python.distribute.cluster_resolver.tpu_cluster_resolver import TPUClusterResolver
24from tensorflow.python.eager import context
25from tensorflow.python.eager import function
26from tensorflow.python.framework import device
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import ops
29from tensorflow.python.platform import tf_logging as logging
30from tensorflow.python.tpu import topology
31from tensorflow.python.tpu import tpu
32from tensorflow.python.util import compat
33from tensorflow.python.util.tf_export import tf_export
34
35
36_INITIALIZED_TPU_SYSTEMS = {}
37_LOCAL_MASTERS = ("", "local")
38
39
40@tf_export("tpu.experimental.initialize_tpu_system")
41def initialize_tpu_system(cluster_resolver=None):
42  """Initialize the TPU devices.
43
44  Args:
45    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
46        which provides information about the TPU cluster.
47  Returns:
48    The tf.tpu.Topology object for the topology of the TPU cluster. If called
49    inside tf.function, it returns the serialized topology object instead.
50
51  Raises:
52    RuntimeError: If running inside a tf.function.
53    NotFoundError: If no TPU devices found in eager mode.
54  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified, and running eagerly, execute the init
    # ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "TPU system %s has already been initialized. "
        "Reinitializing the TPU can cause previously created "
        "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system: %s", tpu_name)

  # This function is written the way it is for the following non-intuitive
  # reason: tpu.initialize_system creates a dummy op whose sole purpose is to
  # trigger DistributedTPURewritePass. That pass is what actually adds the real
  # ops that initialize the TPU system. Thus, we can't simply run
  # tpu.initialize_system eagerly; we need to wrap it in a defun and trigger
  # the rewrite passes on it.
  if tpu_name not in _LOCAL_MASTERS:
    # Explicitly place tpu.initialize_system on the first worker to avoid the
    # "output node matches multiple devices" error.
    job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

  if context.executing_eagerly():
    @function.defun
    def _tpu_init_fn():
      # In TF1, we usually close the TPU chips when compilation fails, in order
      # to clear the data in infeed. In TF2, we don't need to do this because
      # infeed is no longer used, so users can recover from TPU compilation
      # failures more smoothly.
      return tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)

    # The TPU_SYSTEM device must match the device used in tpu.initialize_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    try:
      with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
        output = _tpu_init_fn()
      context.async_wait()
    except errors.InvalidArgumentError as e:
      raise errors.NotFoundError(
          None, None,
          "TPUs not found in the cluster. Failed in initialization: "
          + str(e))

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access
    context.context()._initialize_logical_devices()  # pylint: disable=protected-access
    context.context().clear_kernel_cache()

    serialized_topology = output.numpy()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())
  else:
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      serialized_topology = tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)
      # If initialize_tpu_system is called inside tf.function, we only return
      # the serialized topology object as the tf.tpu.Topology object has to be
      # constructed in eager mode.
      return serialized_topology
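      # Illustrative sketch (an assumption about the caller, not code in this
      # module): a caller that invoked this inside a tf.function can rebuild
      # the Topology once back in eager mode, e.g.
      #
      #   serialized = init_fn()  # hypothetical tf.function wrapping this call
      #   tpu_topology = topology.Topology(serialized=serialized.numpy())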

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology


@tf_export("tpu.experimental.shutdown_tpu_system")
def shutdown_tpu_system(cluster_resolver=None):
144  """Shuts down the TPU devices.
145
146  This will clear all caches, even those that are maintained through sequential
147  calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
148  cache.
149
150  Args:
151    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
152        which provides information about the TPU cluster.
153
154  Raises:
155    RuntimeError: If no TPU devices found for eager execution or if run in a
156        tf.function.
157  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified, and running eagerly, execute the
    # shutdown ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
    logging.warning("You are shutting down a TPU system %s that has not been "
                    "initialized.", tpu_name)

  logging.info("Shutting down the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function is written the way it is for the following non-intuitive
    # reason: tpu.shutdown_system creates a dummy op whose sole purpose is to
    # trigger DistributedTPURewritePass. That pass is what actually adds the
    # real ops that shut down the TPU system. Thus, we can't simply run
    # tpu.shutdown_system eagerly; we need to wrap it in a defun and trigger
    # the rewrite passes on it.
    if tpu_name not in _LOCAL_MASTERS:
      # Explicitly place tpu.shutdown_system on the first worker to avoid the
      # "output node matches multiple devices" error.
      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    @function.defun
    def _tpu_shutdown_fn():
      tpu.shutdown_system(job=job)

    # The TPU_SYSTEM device must match the device used in tpu.shutdown_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      _tpu_shutdown_fn()

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access
    context.context().clear_kernel_cache()
  elif not ops.executing_eagerly_outside_functions():
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        sess.run(tpu.shutdown_system())
  else:
    raise RuntimeError("shutdown_tpu_system is not supported within "
                       "tf.functions.")

  logging.info("Finished shutting down TPU system.")
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    del _INITIALIZED_TPU_SYSTEMS[tpu_name]