# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""TPUEstimator class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import copy
import os
import signal
import sys
import threading
import time

import numpy as np
import six
from six.moves import queue as Queue  # pylint: disable=redefined-builtin
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.core.framework import variable_pb2
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf.tpu import compilation_result_pb2 as tpu_compilation_result
from tensorflow.python.client import session as tf_session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest as data_nest
from tensorflow.python.estimator import estimator as estimator_lib
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.export import export_output as export_output_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import summary_ops_v2 as contrib_summary
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary
from tensorflow.python.tpu import _tpu_estimator_embedding
from tensorflow.python.tpu import error_handling
from tensorflow.python.tpu import functional as tpu_functional
from tensorflow.python.tpu import session_support
from tensorflow.python.tpu import tensor_tracer
from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_config
from tensorflow.python.tpu import tpu_context
from tensorflow.python.tpu import tpu_embedding_gradient
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu import training_loop
from tensorflow.python.tpu import util as util_lib
from tensorflow.python.tpu._tpu_estimator_embedding import AdagradParameters  # pylint: disable=unused-import
from tensorflow.python.tpu._tpu_estimator_embedding import AdamParameters  # pylint: disable=unused-import
from tensorflow.python.tpu._tpu_estimator_embedding import StochasticGradientDescentParameters  # pylint: disable=unused-import
from tensorflow.python.tpu._tpu_estimator_embedding import EmbeddingConfigSpec  # pylint: disable=unused-import
from tensorflow.python.tpu.ops import tpu_ops
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import evaluation
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training
from tensorflow.python.training import training_util
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect

_INITIAL_LOSS = 1e7
_ZERO_LOSS = 0.
_TPU_ESTIMATOR = 'tpu_estimator'
_ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop'
_BATCH_SIZE_KEY = 'batch_size'
_CTX_KEY = 'context'
_USE_TPU_KEY = 'use_tpu'
_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'
_ONE_GIGABYTE = 1024 * 1024 * 1024
_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'
_TPU_TRAIN_OP = '_tpu_train_op'
_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference'
_KEY_WHEN_PREDICTIONS_IS_A_TENSOR = '_key_when_predictions_is_a_tensor'

# Ideally _USE_TPU_KEY should be reserved as well. However there are already
# models that make use of this key, thus it can not be reserved now to prevent
# breakage. In the long run, we would like to mitigate this by migrating models
# off of using _USE_TPU_KEY.
_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY]

# TODO(b/65703635): Flip the value and remove all dead code. Currently, this is
# only used for per-core based deployments. For per-host based pipelines, if a
# user returns a Dataset instance it will be automatically wrapped in a
# tf.while_loop (This can be disabled by returning features and labels
# explicitly).
_WRAP_INPUT_FN_INTO_WHILE_LOOP = False

ops.register_proto_function(
    '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR),
    proto_type=variable_pb2.VariableDef,
    to_proto=resource_variable_ops._to_proto_fn,  # pylint: disable=protected-access
    from_proto=resource_variable_ops._from_proto_fn)  # pylint: disable=protected-access


def _is_iterable(obj):
  """A Python 2 and 3 compatible util to check whether `obj` is iterable."""
  try:
    iter(obj)
    return True
  except TypeError:
    return False


class CatchInvalidHostcallFunctions(control_flow_ops.XLAControlFlowContext):

  def AddOp(self, op):
    if op.type in [
        'AudioSummary', 'AudioSummaryV2', 'HistogramSummary', 'ImageSummary',
        'MergeSummary', 'ScalarSummary', 'TensorSummary', 'TensorSummaryV2'
    ]:
      raise ValueError('Use tf.contrib.summary inside of host_calls.')


def _create_global_step(graph):
  graph = graph or ops.get_default_graph()
  if training.get_global_step(graph) is not None:
    raise ValueError('"global_step" already exists.')
  # Create in proper graph and base name_scope.
  with graph.as_default() as g, g.name_scope(None):
    return variable_scope.get_variable(
        ops.GraphKeys.GLOBAL_STEP,
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        use_resource=True,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])


def _create_or_get_iterations_per_loop():
  """Creates or gets the iterations_per_loop variable.

  In TPUEstimator, the user provided computation, the model_fn, is wrapped
  inside a tf.while_loop for peak performance. The iterations of the loop are
  specified by this variable, which adjusts its value on the CPU after each TPU
  program execution and before the next TPU execution.

  The purpose of using a variable, rather than a constant, is to allow
  TPUEstimator to adapt the TPU training iterations according to the final
  steps specified by users. For example, if the user sets iterations_per_loop
  as 4 in TPUConfig and steps as 10 in TPUEstimator.train(), the
  iterations_per_loop variable will have the following value before each TPU
  training run:

      - 1st TPU execution: iterations_per_loop = 4
      - 2nd TPU execution: iterations_per_loop = 4
      - 3rd TPU execution: iterations_per_loop = 2

  As model_fn increases the global step once per train_op invocation, the
  global step is 10 after all TPU executions, matching the steps=10 passed in
  by users.

  Returns:
    A TF non-trainable resource variable.

  Raises:
    RuntimeError: If multiple iterations_per_loop variables are found.
  """
  graph = ops.get_default_graph()
  collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR)
  iter_vars = graph.get_collection(collection_name)
  if len(iter_vars) == 1:
    return iter_vars[0]
  elif len(iter_vars) > 1:
    raise RuntimeError('Multiple iterations_per_loop_var in collection.')

  with ops.colocate_with(training_util.get_global_step()):
    with variable_scope.variable_scope(
        _TPU_ESTIMATOR, reuse=variable_scope.AUTO_REUSE):
      return variable_scope.get_variable(
          _ITERATIONS_PER_LOOP_VAR,
          initializer=init_ops.zeros_initializer(),
          shape=[],
          dtype=dtypes.int32,
          trainable=False,
          collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES],
          use_resource=True)

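
# A minimal illustrative sketch (not part of TPUEstimator): a hypothetical
# helper mirroring the scheduling described in the docstring above, assuming
# the remaining step budget shrinks by the executed iterations after each TPU
# program execution.  With total_steps=10 and iterations_per_loop=4 it yields
# [4, 4, 2], matching the docstring example.
def _example_iterations_per_loop_schedule(total_steps, iterations_per_loop):
  """Returns per-execution iteration counts, e.g. (10, 4) -> [4, 4, 2]."""
  schedule = []
  remaining = total_steps
  while remaining > 0:
    iterations = min(iterations_per_loop, remaining)
    schedule.append(iterations)
    remaining -= iterations
  return schedule
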

def _sync_variables_ops(ctx):
  """Creates variable synchronization ops.

  Gets the variables back from TPU nodes. This means the variables updated
  by TPU will now be *synced* to host memory.
  In BROADCAST mode, we skip this sync since the variables are usually too
  big to transmit via RPC.

  Args:
    ctx: A `_InternalTPUContext` instance with mode.

  Returns:
    A list of sync ops.
  """

  if not ctx.is_input_broadcast_with_iterators():
    return [
        array_ops.check_numerics(v.read_value(),
                                 'Gradient for %s is NaN' % v.name).op
        for v in variables.trainable_variables()
    ]
  else:
    return [control_flow_ops.no_op()]


def _increase_eval_step_op(iterations_per_loop):
  """Returns an op to increase the eval step for TPU evaluation.

  Args:
    iterations_per_loop: Tensor. The number of eval steps running in TPU system
      before returning to CPU host for each `Session.run`.

  Returns:
    An operation.
  """
  eval_step = evaluation._get_or_create_eval_step()  # pylint: disable=protected-access
  # The Estimator's evaluation loop already increments the eval step by 1 per
  # `Session.run`, so we add the remaining `iterations_per_loop - 1` here.
  return state_ops.assign_add(
      eval_step,
      math_ops.cast(iterations_per_loop - 1, dtype=eval_step.dtype),
      use_locking=True)


def _extract_key_names(tensor_or_dict):
  if isinstance(tensor_or_dict, dict):
    return sorted(tensor_or_dict.keys())
  return []


class _SIGNAL(object):
  """Signals used to control the infeed/outfeed threads.

  All preserved signals must be negative numbers. Positive numbers are used to
  indicate the number of iterations for the next training/evaluation loop.
  """
  NEXT_BATCH = -1
  STOP = -2


class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
  """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.

  See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and
  `export_outputs`.

  For evaluation, `eval_metrics` is a tuple of `metric_fn` and `tensors`, where
  `metric_fn` runs on CPU to generate metrics and `tensors` represents the
  `Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`.
  To be precise, TPU evaluation expects a slightly different signature from the
  `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a
  dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
  The `tensors` could be a list of `Tensor`s or a dict of names to `Tensor`s.
  The `tensors` usually specify the model logits, which are transferred back
  from the TPU system to the CPU host. All tensors must be batch-major, i.e.,
  the batch size is the first dimension. Once all tensors are available at the
  CPU host from all shards, they are concatenated (on CPU) and passed as
  positional arguments to the `metric_fn` if `tensors` is a list, or as keyword
  arguments if `tensors` is a dict. `metric_fn` takes the `tensors` and returns
  a dict from metric string name to the result of calling a metric function,
  namely a `(metric_tensor, update_op)` tuple. See `TPUEstimator` for an MNIST
  example of how to specify `eval_metrics`.

  `scaffold_fn` is a function running on CPU to generate the `Scaffold`. This
  function should not capture any Tensors in `model_fn`.

  `host_call` is a tuple of a `function` and a list or dictionary of `tensors`
  to pass to that function; the function returns a list of Tensors. `host_call`
  currently works for train() and evaluate(). The Tensors returned by the
  function are evaluated on the CPU on every step, so there is communication
  overhead when sending tensors from TPU to CPU. To reduce the overhead, try
  reducing the size of the tensors. The `tensors` are concatenated along their
  major (batch) dimension, and so must be >= rank 1. The `host_call` is useful
  for writing summaries with `tf.contrib.summary.create_file_writer`.
  """

  def __new__(cls,
              mode,
              predictions=None,
              loss=None,
              train_op=None,
              eval_metrics=None,
              export_outputs=None,
              scaffold_fn=None,
              host_call=None,
              training_hooks=None,
              evaluation_hooks=None,
              prediction_hooks=None):
    """Creates a validated `TPUEstimatorSpec` instance."""
    host_calls = {}
    if eval_metrics is not None:
      host_calls['eval_metrics'] = eval_metrics
    if host_call is not None:
      host_calls['host_call'] = host_call
    _OutfeedHostCall.validate(host_calls)

    training_hooks = tuple(training_hooks or [])
    evaluation_hooks = tuple(evaluation_hooks or [])
    prediction_hooks = tuple(prediction_hooks or [])

    for hook in training_hooks + evaluation_hooks + prediction_hooks:
      if not isinstance(hook, session_run_hook.SessionRunHook):
        raise TypeError('All hooks must be SessionRunHook instances, given: {}'
                        .format(hook))

    return super(TPUEstimatorSpec, cls).__new__(
        cls,
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        export_outputs=export_outputs,
        scaffold_fn=scaffold_fn,
        host_call=host_call,
        training_hooks=training_hooks,
        evaluation_hooks=evaluation_hooks,
        prediction_hooks=prediction_hooks)

  def as_estimator_spec(self):
    """Creates an equivalent `EstimatorSpec` used by CPU train/eval."""
    host_calls = {}
    if self.eval_metrics is not None:
      host_calls['eval_metrics'] = self.eval_metrics
    if self.host_call is not None:
      host_calls['host_call'] = self.host_call
    host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls)
    eval_metric_ops = None
    if self.eval_metrics is not None:
      eval_metric_ops = host_call_ret['eval_metrics']
    hooks = None
    if self.host_call is not None:
      hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]
    loss = self.loss
    if tensor_tracer.TensorTracer.is_enabled() \
       and self.train_op is not None:
      tt = tensor_tracer.TensorTracer()
      loss = tt.trace_cpu(ops.get_default_graph(), loss, self.train_op)

    hooks = tuple(hooks or [])
    scaffold = self.scaffold_fn() if self.scaffold_fn else None
    return model_fn_lib.EstimatorSpec(
        mode=self.mode,
        predictions=self.predictions,
        loss=loss,
        train_op=self.train_op,
        eval_metric_ops=eval_metric_ops,
        export_outputs=self.export_outputs,
        scaffold=scaffold,
        training_hooks=self.training_hooks + hooks,
        evaluation_hooks=self.evaluation_hooks + hooks,
        prediction_hooks=self.prediction_hooks + hooks)

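
# Illustrative sketch only (not part of this module's API): a hypothetical
# EVAL-mode fragment showing the `eval_metrics` and `host_call` contracts
# described in the `TPUEstimatorSpec` docstring above.  `loss`, `labels` and
# `logits` are assumed to be Tensors produced elsewhere in a user's `model_fn`;
# the local `metrics` import is an assumption for the sketch and is not used by
# the rest of this file.
def _example_eval_tpu_estimator_spec(loss, labels, logits):
  """Builds a hypothetical EVAL-mode `TPUEstimatorSpec`, for illustration."""
  from tensorflow.python.ops import metrics as metrics_lib  # sketch-only import

  def metric_fn(labels, logits):
    # Runs on the CPU host on the dequeued, concatenated, batch-major tensors;
    # the dict form of `tensors` below is passed as keyword arguments.
    predictions = math_ops.argmax(logits, axis=-1)
    return {'accuracy': metrics_lib.accuracy(labels, predictions)}

  def host_call_fn(per_step_loss):
    # Also runs on the CPU host and must return a list of ops.  `host_call`
    # tensors are concatenated along their batch dimension, hence the rank-1
    # reshape below.
    return [contrib_summary.scalar('loss', math_ops.reduce_mean(per_step_loss))]

  return TPUEstimatorSpec(
      mode=model_fn_lib.ModeKeys.EVAL,
      loss=loss,
      eval_metrics=(metric_fn, {'labels': labels, 'logits': logits}),
      host_call=(host_call_fn, [array_ops.reshape(loss, [1])]))
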

class _OpQueueContext(object):
  """Manages the work queue and thread for an infeed/outfeed thread."""

  def __init__(self, name, target, args):
    self._name = name
    self._queue = Queue.Queue()
    args = (self,) + args
    self._thread = threading.Thread(name=name, target=target, args=args)
    self._thread.daemon = True
    self._thread.start()

  def stop(self):
    self._queue.put(_SIGNAL.STOP)

  def send_next_batch_signal(self, iterations):
    self._queue.put(iterations)

  def read_iteration_counts(self):
    while True:
      iterations = self._queue.get(block=True)
      logging.debug('%s read iterations %s', self._name, iterations)
      if iterations == _SIGNAL.STOP:
        logging.info('%s received shutdown signal, stopping.', self._name)
        return
      yield iterations

  def join(self):
    logging.info('Shutting down %s thread.', self._name)
    self.stop()
    self._thread.join()

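
# Illustrative sketch (an assumption, not code used by this module): how an
# `_OpQueueContext` is typically driven.  The context prepends itself to
# `args`, so `target` receives `(queue_ctx, session)`; the worker below mirrors
# the infeed/outfeed loops in `TPUInfeedOutfeedSessionHook`.
def _example_op_queue_context_usage(session, enqueue_ops):
  """Runs `enqueue_ops` for the signalled iterations, then shuts down."""

  def worker(queue_ctx, sess):
    for steps in queue_ctx.read_iteration_counts():
      for _ in xrange(steps):
        sess.run(enqueue_ops)

  ctx = _OpQueueContext(name='ExampleInfeed', target=worker, args=(session,))
  ctx.send_next_batch_signal(2)  # enqueue two iterations of data
  ctx.join()                     # puts _SIGNAL.STOP and waits for the thread
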

class _OpSignalOnceQueueContext(_OpQueueContext):
  """Manages the work queue and thread for an infeed/outfeed thread.

  This subclass only signals once.
  """

  def __init__(self, name, target, args):
    super(_OpSignalOnceQueueContext, self).__init__(name, target, args)
    self._has_signaled = False

  def send_next_batch_signal(self, iterations):
    if not self._has_signaled:
      self._queue.put(iterations)
      self._has_signaled = True


class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
  """A Session hook setting up the TPU initialization, infeed, and outfeed.

  This hook does two major things:
  1. initializes and shuts down the TPU system.
  2. launches and joins the threads for infeed enqueue and (optional) outfeed
     dequeue.
  """

  def __init__(self,
               ctx,
               enqueue_ops,
               dequeue_ops,
               tpu_compile_op,
               run_infeed_loop_on_coordinator=True,
               rendezvous=None,
               master=None,
               session_config=None,
               tpu_init_ops=None):
    self._master_job = ctx.master_job
    self._enqueue_ops = enqueue_ops
    self._dequeue_ops = dequeue_ops
    self._rendezvous = rendezvous
    self._master = master
    self._session_config = session_config
    self._init_ops = list(tpu_init_ops or [])
    if ctx.embedding_config is None:
      self._embedding_layer_config = None
    else:
      self._embedding_layer_config = (
          ctx.embedding_config.tpu_embedding.config_proto)
    self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator
    self._initial_infeed_sleep_secs = (
        ctx.config.tpu_config.initial_infeed_sleep_secs)

    self._feed_error = None
    self._finished = False
    # When using model parallelism, the TPU is pre-initialized at startup to
    # fetch mesh information.  We skip re-initializing it here to avoid
    # suspected issues due to the mesh layout changing on the second
    # initialization.
    self._should_initialize_tpu = not ctx.model_parallelism_enabled
    self._tpu_compile_op = tpu_compile_op

  def begin(self):
    logging.info('TPU job name %s', self._master_job)
    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
    if self._should_initialize_tpu:
      self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
    else:
      self._finalize_ops = []

    summary_writer_init_ops = contrib_summary.summary_writer_initializer_op()
    self._init_ops.extend(summary_writer_init_ops)
    # Get all the writer resources from the initializer, so we know what to
    # flush.
    for op in summary_writer_init_ops:
      self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0]))

  def _run_infeed(self, queue_ctx, session):
    logging.info('Starting infeed thread controller.')
    if self._initial_infeed_sleep_secs:
      logging.info('Infeed thread sleeping for %d seconds.',
                   self._initial_infeed_sleep_secs)
      time.sleep(self._initial_infeed_sleep_secs)
      logging.info('Infeed thread starting after sleep')

    with self._rendezvous.catch_errors(source='infeed', session=session):
      if self._run_infeed_loop_on_coordinator:
        for count, steps in enumerate(queue_ctx.read_iteration_counts()):
          for i in xrange(steps):
            logging.debug('Infeed enqueue for iteration (%d, %d)', count, i)
            session.run(self._enqueue_ops)
      else:
        for _ in queue_ctx.read_iteration_counts():
          session.run(self._enqueue_ops)
      logging.info('Infeed thread finished, shutting down.')

  def _run_outfeed(self, queue_ctx, session):
    logging.info('Starting outfeed thread controller.')
    with self._rendezvous.catch_errors(source='outfeed', session=session):
      for count, steps in enumerate(queue_ctx.read_iteration_counts()):
        for i in xrange(steps):
          logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i)
          session.run(self._dequeue_ops)
      logging.info('Outfeed thread finished, shutting down.')

  def _create_infeed_controller(self, name, target, args):
    return _OpQueueContext(name=name, target=target, args=args)

  def _assertCompilationSucceeded(self, result, coord):
    proto = tpu_compilation_result.CompilationResultProto()
    proto.ParseFromString(result)
    if proto.status_error_message:
      logging.error('Compilation failed: {}'.format(proto.status_error_message))
      coord.request_stop()
    else:
      logging.info('Compilation succeeded')

  def after_create_session(self, session, coord):
    if self._should_initialize_tpu:
      logging.info('Init TPU system')
      start = time.time()
      with ops.Graph().as_default():
        with tf_session.Session(
            self._master, config=self._session_config) as sess:
          sess.run(
              tpu.initialize_system(
                  job=self._master_job,
                  embedding_config=self._embedding_layer_config))
      logging.info('Initialized TPU in %d seconds', time.time() - start)

    session.run(self._init_ops,
                options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))

    if os.environ.get('TPU_SPLIT_COMPILE_AND_EXECUTE', '') == '1':
      logging.info('Compiling user program: this may take a while...')
      self._assertCompilationSucceeded(session.run(self._tpu_compile_op), coord)

    self._infeed_controller = self._create_infeed_controller(
        name='InfeedController', target=self._run_infeed, args=(session,))

    self._outfeed_controller = _OpQueueContext(
        name='OutfeedController', target=self._run_outfeed, args=(session,))

    # Enable the worker watchdog to terminate workers on coordinator exit.
    watchdog_timeout = int(os.environ.get('TF_TPU_WATCHDOG_TIMEOUT', '0'))
    if watchdog_timeout > 0:
      session_support.start_worker_watchdog(session,
                                            shutdown_timeout=watchdog_timeout)

  def before_run(self, run_context):
    self._feed_error = None

    iterations = run_context.session.run(self._iterations_per_loop_var)

    logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations)
    self._infeed_controller.send_next_batch_signal(iterations)

    logging.info('Dequeue next (%d) batch(es) of data from outfeed.',
                 iterations)
    self._outfeed_controller.send_next_batch_signal(iterations)

  def end(self, session):
    self._finished = True
    logging.info('Stop infeed thread controller')
    self._infeed_controller.join()
    self._rendezvous.record_done('infeed')

    logging.info('Stop output thread controller')
    self._outfeed_controller.join()
    self._rendezvous.record_done('outfeed')

    logging.info('Shutdown TPU system.')
    session.run(self._finalize_ops)

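
# Note (an illustrative assumption, not an official API): the hook above
# consults two environment variables when the session is created.  Setting
# them in the launching process before the Estimator runs is sufficient, e.g.:
#
#   os.environ['TPU_SPLIT_COMPILE_AND_EXECUTE'] = '1'  # compile, then execute
#   os.environ['TF_TPU_WATCHDOG_TIMEOUT'] = '600'      # watchdog timeout (secs)
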

class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook):

  def __init__(self, ctx, enqueue_ops, dequeue_ops, tpu_compile_op,
               rendezvous=None, master=None, session_config=None):
    super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__(
        ctx,
        enqueue_ops,
        dequeue_ops,
        tpu_compile_op=tpu_compile_op,
        run_infeed_loop_on_coordinator=False,
        rendezvous=rendezvous,
        master=master,
        session_config=session_config)

  def _create_infeed_controller(self, name, target, args):
    return _OpSignalOnceQueueContext(name=name, target=target, args=args)


class _TPUStopAtStepHook(session_run_hook.SessionRunHook):
  """Hook that requests stop at a specified step.

  This hook is similar to the `session_run_hook._StopAfterNEvalsHook` with
  the following differences for TPU training:

  1. This hook sets the variable for iterations_per_loop, which is used by
     `TPUInfeedOutfeedSessionHook` to control the iterations for infeed/outfeed.
     As the hook execution order is not guaranteed, the variable update is
     handled in `after_create_session` and `after_run` as
     `TPUInfeedOutfeedSessionHook` reads the variable value in `before_run`.

  2. For each training loop (session.run), the global step could be increased
     multiple times on TPU. The global step tensor value will be explicitly read
     again in `after_run` to ensure the latest value is retrieved and to avoid a
     race condition.
  """

  def __init__(self, iterations, num_steps=None, last_step=None):
    """Initializes a `_TPUStopAtStepHook`.

    Args:
      iterations: The number of iterations to run the optimizer per training
        loop.
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError('One of num_steps or last_step must be specified.')
    if num_steps is not None and last_step is not None:
      raise ValueError('Only one of num_steps or last_step can be specified.')
    self._num_steps = num_steps
    self._last_step = last_step
    self._iterations = iterations

  def _next_iterations(self, global_step, last_step):
    gap = last_step - global_step
    return min(gap, self._iterations)

  def begin(self):
    self._global_step_tensor = training_util.get_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError('Global step should be created.')

    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()

  def after_create_session(self, session, coord):
    global_step = session.run(self._global_step_tensor)
    if self._last_step is None:
      self._last_step = global_step + self._num_steps

    iterations = self._next_iterations(global_step, self._last_step)

    self._iterations_per_loop_var.load(iterations, session=session)

  def after_run(self, run_context, run_values):
    # Global step cannot be retrieved via SessionRunArgs and before_run due to
    # a race condition.
    global_step = run_context.session.run(self._global_step_tensor)
    if global_step >= self._last_step:
      run_context.request_stop()
    else:
      iterations = self._next_iterations(global_step, self._last_step)
      self._iterations_per_loop_var.load(
          iterations, session=run_context.session)


class _SetEvalIterationsHook(session_run_hook.SessionRunHook):
  """Hook that sets the number of iterations per loop for evaluation."""

  def __init__(self, num_steps):
    """Initializes a `_SetEvalIterationsHook`.

    Args:
      num_steps: Number of steps to execute.
    """
    self._num_steps = num_steps

  def begin(self):
    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()

  def after_create_session(self, session, coord):
    self._iterations_per_loop_var.load(self._num_steps, session=session)


class _StoppingPredictHook(session_run_hook.SessionRunHook):
  """Hook that requests stop according to the stopping signal in prediction."""

  def __init__(self, scalar_stopping_signal):
    self._scalar_stopping_signal = scalar_stopping_signal

  def begin(self):
    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()

  def after_create_session(self, session, coord):
    # This is not necessary as we do not run infeed enqueue and outfeed dequeue
    # in side threads for the prediction model, but it makes
    # `TPUInfeedOutfeedSessionHook` print a nicer message.
    self._iterations_per_loop_var.load(1, session=session)

  def before_run(self, run_context):
    return session_run_hook.SessionRunArgs(self._scalar_stopping_signal)

  def after_run(self, run_context, run_values):
    _ = run_context
    scalar_stopping_signal = run_values.results
    if _StopSignals.should_stop(scalar_stopping_signal):
      # NOTE(xiejw): In prediction, stopping signals are inserted for each
      # batch, and we append one more batch to signal the system that it should
      # stop. The data flow might look like
      #
      #  batch   0: images, labels, stop = 0  (user provided)
      #  batch   1: images, labels, stop = 0  (user provided)
      #  ...
      #  batch  99: images, labels, stop = 0  (user provided)
      #  batch 100: images, labels, stop = 1  (TPUEstimator appended)
      #
      # where the final batch (id = 100) is appended by TPUEstimator, so we
      # should drop it before returning the predictions to the user.
      # To achieve that, we throw the OutOfRangeError in after_run. Once
      # MonitoredSession sees this error in SessionRunHook.after_run, the
      # "current" prediction, i.e., the batch with id=100, will be discarded
      # immediately.
      raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.')


def generate_per_core_enqueue_ops_fn_for_host(
    ctx, input_fn, inputs_structure_recorder, host_device, host_id):
  """Generates infeed enqueue ops for per-core input_fn on a single host."""
  captured_infeed_queue = _CapturedObject()
  tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)

  def enqueue_ops_fn():
    """A fn that returns enqueue_ops."""
    num_cores_per_host = ctx.num_of_cores_per_host
    per_host_sharded_inputs = []
    for core_ordinal in range(num_cores_per_host):
      with ops.name_scope('ordinal_%d' % (core_ordinal)):
        user_context = tpu_context.TPUContext(
            internal_ctx=ctx,
            input_device=host_device,
            invocation_index=host_id * ctx.num_of_cores_per_host + core_ordinal)
        inputs = _Inputs.from_input_fn(input_fn(user_context))
        if inputs.is_dataset:
          raise TypeError(
              '`input_fn` returning `Dataset` is not yet supported in '
              'per-Core input pipeline deployment. Please set '
              'TPUConfig.per_host_input_for_training to True or return '
              '`features` and `labels` from `input_fn`')
        features, labels = inputs.features_and_labels()

        inputs_structure_recorder.validate_and_record_structure(
            features, labels)
        flattened_inputs = (
            inputs_structure_recorder.flatten_features_and_labels(
                features, labels))
        per_host_sharded_inputs.append(flattened_inputs)

    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(per_host_sharded_inputs[0]))
    captured_infeed_queue.capture(infeed_queue)

    per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
        per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl)
    return per_host_enqueue_ops

  return enqueue_ops_fn, captured_infeed_queue


def generate_per_host_enqueue_ops_fn_for_host(
    ctx, input_fn, inputs_structure_recorder, batch_axis, device, host_id):
  """Generates infeed enqueue ops for per-host input_fn on a single host."""
  captured_infeed_queue = _CapturedObject()

  dataset_initializer = None

  with ops.device(device):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx, input_device=device, invocation_index=host_id)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      if not is_dataset:
        raise TypeError(
            'For mode PREDICT, `input_fn` must return `Dataset` instead of '
            '`features` and `labels`.')
      if batch_axis is not None:
        raise TypeError('For mode PREDICT, batch_axis is not supported yet.')
      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True)

    if is_dataset:
      dataset_initializer = inputs.dataset_initializer()

    tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)

  def enqueue_ops_fn():
    """A fn that returns the TPU infeed enqueue ops.

    By providing this as a fn, it can be invoked inside the tf.while_loop such
    that the input pipeline for multiple iterations can be executed by one
    Session.run call.

    Returns:
      list of dict of ops.
    """
    with ops.device(device):
      num_of_replicas_per_host = ctx.num_of_replicas_per_host
      # Convert user input to features and labels.  If the user returns a
      # dataset, it is initialized and the features and labels extracted via
      # `dataset.iterator.get_next()`.
      features, labels = inputs.features_and_labels()
      signals = inputs.signals()

      inputs_structure_recorder.validate_and_record_structure(features, labels)
      unsharded_tensor_list = (
          inputs_structure_recorder.flatten_features_and_labels(
              features, labels, signals))

      infeed_queue = tpu_feed.InfeedQueue(
          tuple_types=[t.dtype for t in unsharded_tensor_list],
          tuple_shapes=[t.shape for t in unsharded_tensor_list],
          shard_dimensions=batch_axis)
      captured_infeed_queue.capture(infeed_queue)
      infeed_queue.set_number_of_shards(num_of_replicas_per_host)
      per_host_enqueue_ops = (
          infeed_queue.split_inputs_and_generate_enqueue_ops(
              unsharded_tensor_list,
              placement_function=lambda x: device,
              tpu_ordinal_function=tpu_ordinal_function_impl))
      if signals is None:
        return per_host_enqueue_ops
      else:
        return {
            'ops': per_host_enqueue_ops,
            'signals': signals,
        }

  return enqueue_ops_fn, captured_infeed_queue, dataset_initializer


def generate_per_host_v2_enqueue_ops_fn_for_host(
    ctx, input_fn, inputs_structure_recorder, device, host_id):
  """Generates infeed enqueue ops for per-host input_fn on a single host."""
  captured_infeed_queue = _CapturedObject()
  dataset_initializer = None

  with ops.device(device):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx, input_device=device, invocation_index=host_id)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if not is_dataset:
      raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 '
                      'input pipeline configuration.')

    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True,
          num_invocations_per_step=ctx.num_of_replicas_per_host)

    dataset_initializer = inputs.dataset_initializer()
    tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)

  def enqueue_ops_fn():
    """Generates the per_host enqueue ops."""
    control_deps = []
    per_host_sharded_inputs = []
    sparse_features_list = []
    num_replicas_per_host = ctx.num_of_replicas_per_host
    cached_signals = None
    with ops.device(device):
      if not inputs.is_dataset:
        raise TypeError('`input_fn` must return a `Dataset` for this mode.')
      for _ in range(num_replicas_per_host):
        # Use control dependencies to ensure a deterministic ordering.
        with ops.control_dependencies(control_deps):
          features, labels = inputs.features_and_labels()  # Calls get_next()
          signals = inputs.signals()

          # All the replicas share replica 0's stopping signal.
          # This avoids inconsistent state among different model replicas.
          if cached_signals:
            signals['stopping'] = cached_signals['stopping']
          else:
            cached_signals = signals

        features, labels, sparse_features = (
            _tpu_estimator_embedding.split_inputs(ctx, features, labels))
        sparse_features_list.append(sparse_features)

        inputs_structure_recorder.validate_and_record_structure(
            features, labels)
        flattened_inputs = (
            inputs_structure_recorder.flatten_features_and_labels(
                features, labels, signals))
        control_deps.extend(flattened_inputs)
        per_host_sharded_inputs.append(flattened_inputs)

      if inputs_structure_recorder.flattened_input_dims:
        input_partition_dims = inputs_structure_recorder.flattened_input_dims
        if signals:
          input_partition_dims += [None] * len(signals)
        # pylint: disable=protected-access
        infeed_queue = tpu_feed._PartitionedInfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]),
            host_id=host_id,
            input_partition_dims=input_partition_dims,
            device_assignment=ctx.device_assignment)
        per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
            per_host_sharded_inputs)
      else:
        infeed_queue = tpu_feed.InfeedQueue(
            number_of_tuple_elements=len(per_host_sharded_inputs[0]))
        per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
            per_host_sharded_inputs,
            tpu_ordinal_function=tpu_ordinal_function_impl)
      captured_infeed_queue.capture(infeed_queue)

    if ctx.embedding_config:
      per_host_enqueue_ops.extend(
          ctx.embedding_config.tpu_embedding.generate_enqueue_ops(
              sparse_features_list))

    if signals is None:
      return per_host_enqueue_ops
    else:
      return {
          'ops': per_host_enqueue_ops,
          'signals': signals,
      }

  return enqueue_ops_fn, captured_infeed_queue, dataset_initializer


def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
                                      num_hosts):
  """Generates infeed enqueue ops for one input_fn on all the hosts."""
  captured_infeed_queue = _CapturedObject()
  dataset_initializer = None
  device_0 = ctx.tpu_host_placement_function(host_id=0)
  with ops.device(device_0):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx, input_device=device_0, invocation_index=0)
    inputs = _Inputs.from_input_fn(input_fn(user_context))

    is_dataset = inputs.is_dataset
    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:
      if not is_dataset:
        raise TypeError(
            'For mode PREDICT, `input_fn` must return `Dataset` instead of '
            '`features` and `labels`.')

      inputs = _InputsWithStoppingSignals(
          dataset=inputs.dataset,
          batch_size=ctx.batch_size_for_input_fn,
          add_padding=True)

    if is_dataset:
      dataset_initializer = inputs.dataset_initializer()
    num_replicas_per_host = ctx.num_of_replicas_per_host

  def tpu_ordinal_function_impl(replica_id):
    if ctx.device_assignment:
      return ctx.device_assignment.tpu_ordinal(replica=replica_id)
    else:
      return replica_id % num_replicas_per_host

  def device_function_impl(replica_id):
    return ctx.tpu_host_placement_function(replica_id=replica_id)

  def enqueue_ops_fn():
    """Generates enqueue ops for all the hosts."""
    broadcasted_inputs = []
    flattened_inputs = None  # Cache result from input_fn.
    signals = None
    num_replicas = ctx.num_replicas
    core_id = 0
    for host_id in xrange(num_hosts):
      with ops.device(ctx.tpu_host_placement_function(host_id=host_id)):
        for _ in xrange(ctx.num_of_replicas_per_host):
          # Note: input_fn is only called once at host 0 for the first replica.
          # The features and labels returned from that invocation are
          # broadcasted to other replicas (including the replicas on other
          # hosts).
          if flattened_inputs is None:
            features, labels = inputs.features_and_labels()  # Calls get_next()
            signals = inputs.signals()

            inputs_structure_recorder.validate_and_record_structure(
                features, labels)
            flattened_inputs = (
                inputs_structure_recorder.flatten_features_and_labels(
                    features, labels, signals))
            if (ctx.config.tpu_config.eval_training_input_configuration is
                tpu_config.InputPipelineConfig.SLICED):
              input_slices = [
                  array_ops.split(x, num_replicas) for x in flattened_inputs
              ]
          if (ctx.config.tpu_config.eval_training_input_configuration is
              tpu_config.InputPipelineConfig.SLICED):
            # For each core, slice out its piece of the flattened inputs.
            broadcasted_inputs.append([x[core_id] for x in input_slices])
            core_id += 1
          else:
            broadcasted_inputs.append(flattened_inputs)

    infeed_queue = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(broadcasted_inputs[0]))
    captured_infeed_queue.capture(infeed_queue)
    enqueue_ops = infeed_queue.generate_enqueue_ops(
        broadcasted_inputs,
        tpu_ordinal_function=tpu_ordinal_function_impl,
        placement_function=device_function_impl)

    if signals is None:
      return enqueue_ops
    else:
      return {
          'ops': enqueue_ops,
          'signals': signals,
      }

  return enqueue_ops_fn, captured_infeed_queue, dataset_initializer


class _InputPipeline(object):
  """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.

  `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from
  the call site.  To be precise, based on the configuration in
  `_InternalTPUContext`, it invokes `input_fn` for all cores (usually
  multi-host TPU training) or for one host (usually for single-host TPU
  evaluation), and sends all `features` and `labels` returned by `input_fn` to
  TPU infeed. For per-core invocation, `features` and `labels` are piped to
  infeed directly, one tuple for each core. For per-host invocation, `features`
  and `labels` are split at host (with respect to `batch_axis`) and piped to all
  cores accordingly.

  In addition, flatten/unflatten are also handled by `_InputPipeline`.  Model
  inputs returned by the `input_fn` can have one of the following forms:
  1. features
  2. (features, labels)
  3. ((arbitrarily nested structure of features), labels)

  Internally, form 1 is reformed to `(features, None)` as features and labels
  are passed separately to underlying methods. For TPU training, TPUEstimator
  may expect multiple `features` and `labels` tuples, one for each core.

  TPUEstimator allows various different structures for inputs (namely `features`
  and `labels`).  Both `features` and `labels` can be any nested structure
  supported by TF nest (namely, dict, tuples, namedtuples or any nested
  structure of such of Tensors).  `labels` could be `None` as well.

  These are flattened before they are passed to the infeed/outfeed library
  as that expects flattened lists.
  """

  class InputsStructureRecorder(object):
    """The recorder to record inputs structure."""

    def __init__(self, input_partition_dims=None):
      # Holds the structure of inputs
      self._feature_structure = {}
      self._flattened_input_dims = None

      if input_partition_dims:
        # This should have been validated in TPUConfig.
        assert len(input_partition_dims) <= 2, 'must have 1 or 2 elements.'
        if len(input_partition_dims) == 2:
          self._feature_dims, self._label_dims = input_partition_dims
        else:
          self._feature_dims = input_partition_dims[0]
          self._label_dims = None

        assert self._feature_dims is not None, ('input_partition_dims[0] must '
                                                'not be None')
      else:
        self._feature_dims = None
        self._label_dims = None

      # Internal state.
      self._initialized = False

    @property
    def flattened_input_dims(self):
      assert self._initialized, 'InputsStructureRecorder is not initialized.'
      return self._flattened_input_dims

    def has_labels(self):
      return 'labels' in self._feature_structure

    def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims,
                            label_dims_names, label_names, has_labels):
      """Flatten input dims with the same order as flattened input tensors."""
      flattened_input_dims = []
      if feature_dims_names:
        # We need a fixed ordering for matching the tensors in features.
        flattened_input_dims.extend(
            [feature_dims[name] for name in feature_dims_names])
      else:
        flattened_input_dims.append(feature_dims)

      if label_dims_names:
        # We need a fixed ordering for matching the tensors in labels.
        flattened_input_dims.extend(
            [label_dims[name] for name in label_dims_names])
      else:
        if label_names:
          num_tensors_in_label = len(label_names)
        else:
          num_tensors_in_label = int(has_labels)
        # Setting `None` in input_partition_dims[1] will apply `None` to
        # all the tensors in labels, regardless of internal structure.
        flattened_input_dims.extend([label_dims] * num_tensors_in_label)

      return flattened_input_dims

    def validate_and_record_structure(self, features, labels):
      """Validates and records the structure of `features` and `labels`."""
      # Extract structure.
      has_labels = labels is not None
      feature_names = _extract_key_names(features)
      label_names = _extract_key_names(labels)

      if not self._initialized:
        # Record structure.
        self._initialized = True
        if self._feature_dims is not None:
          feature_dims_names = _extract_key_names(self._feature_dims)
          if feature_dims_names != feature_names:
            raise ValueError(
                'TPUConfig.input_partition_dims[0] mismatched feature'
                ' keys. Expected {}, got {}'.format(feature_names,
                                                    feature_dims_names))

          label_dims_names = _extract_key_names(self._label_dims)
          if self._label_dims is not None and label_dims_names != label_names:
            raise ValueError(
                'TPUConfig.input_partition_dims[1] mismatched label'
                ' keys. Expected {}, got {}'.format(label_names,
                                                    label_dims_names))

          self._flattened_input_dims = self._flatten_input_dims(
              self._feature_dims, feature_dims_names, self._label_dims,
              label_dims_names, label_names, has_labels)

    def flatten_features_and_labels(self, features, labels, signals=None):
      """Flattens the `features` and `labels` to a single tensor list."""
      self._feature_structure['features'] = features
      if labels is not None:
        self._feature_structure['labels'] = labels
      if signals is not None:
        self._feature_structure['signals'] = signals
      return data_nest.flatten(self._feature_structure)

    def unflatten_features_and_labels(self, flattened_inputs):
      """Restores the flattened inputs to original features and labels form.

      Args:
        flattened_inputs: Flattened inputs for each shard.

      Returns:
        A tuple of (`features`, `labels`), where `labels` could be None.
        Each one, if present, should have identical structure (single tensor vs
        dict) as the one returned by input_fn.

      Raises:
        ValueError: If the number of expected tensors from `flattened_inputs`
          mismatches the recorded structure.
      """

      unflattened_inputs = data_nest.pack_sequence_as(self._feature_structure,
                                                      flattened_inputs)
      return _Inputs(
          unflattened_inputs['features'],
          unflattened_inputs.get('labels'),
          signals=unflattened_inputs.get('signals'))

  def __init__(self, input_fn, batch_axis, ctx):
    """Constructor.

    Args:
      input_fn: input fn for train or eval.
      batch_axis: A python tuple of int values describing how each tensor
        produced by the Estimator `input_fn` should be split across the TPU
        compute shards.
      ctx: A `_InternalTPUContext` instance with mode.

    Raises:
      ValueError: If both `sharded_features` and `num_cores` are `None`.
    """
    self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder(
        ctx.input_partition_dims)

    self._sharded_per_core = ctx.is_input_sharded_per_core()
    self._input_fn = input_fn
    self._infeed_queue = None
    self._ctx = ctx
    self._batch_axis = batch_axis

  def generate_infeed_enqueue_ops_and_dequeue_fn(self):
    """Generates infeed enqueue ops and dequeue_fn."""
    # When tf.while_loop is called, the body function, which invokes the
    # `enqueue_fn` passed in, is called to construct the graph, so the
    # input_fn structure is recorded.
    enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = (
        self._invoke_input_fn_and_record_structure())

    self._validate_input_pipeline()

    def dequeue_fn():
      """dequeue_fn is used by TPU to retrieve the tensors."""
      # In the model-parallel case, both the host-side and device-side
      # computations must agree on the core on which infeed takes place. We
      # choose to perform infeed on logical core 0 of each replica.
      values = self._infeed_queue.generate_dequeue_op(tpu_device=0)
      # The unflatten process uses the structure information recorded above.
      return self._inputs_structure_recorder.unflatten_features_and_labels(
          values)

    return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator)

1234  def _invoke_input_fn_and_record_structure(self):
1235    """Deploys the input pipeline and record input structure."""
1236    enqueue_ops = []
1237    infeed_queues = []
1238    all_dataset_initializers = []
1239    num_hosts = self._ctx.num_hosts
1240    tpu_host_placement_fn = self._ctx.tpu_host_placement_function
1241
1242    run_infeed_loop_on_coordinator = True
1243
1244    if self._sharded_per_core:
1245      # Per-Core input pipeline deployment.
1246      # Invoke input pipeline for each core and placed on the corresponding
1247      # host.
1248      for host_id in range(num_hosts):
1249        host_device = tpu_host_placement_fn(host_id=host_id)
1250        with ops.device(host_device):
1251          with ops.name_scope('input_pipeline_task%d' % (host_id)):
1252            enqueue_ops_fn, captured_infeed_queue = (
1253                generate_per_core_enqueue_ops_fn_for_host(
1254                    self._ctx, self._input_fn, self._inputs_structure_recorder,
1255                    host_device, host_id))
1256
1257            if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
1258              run_infeed_loop_on_coordinator = False
1259              enqueue_ops.append(
1260                  _wrap_computation_in_while_loop(
1261                      device=host_device, op_fn=enqueue_ops_fn))
1262            else:
1263              enqueue_ops.append(enqueue_ops_fn())
1264            # Infeed_queue_getter must be called after enqueue_ops_fn is called.
1265            infeed_queues.append(captured_infeed_queue.get())
1266
1267    elif self._ctx.is_input_broadcast_with_iterators():
1268      # Only calls input_fn in host 0.
1269      host_device = tpu_host_placement_fn(host_id=0)
1270      enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
1271          generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn,
1272                                            self._inputs_structure_recorder,
1273                                            num_hosts))
1274      if dataset_initializer:
1275        all_dataset_initializers.append(dataset_initializer)
1276        run_infeed_loop_on_coordinator = False
1277        wrap_fn = (
1278            _wrap_computation_in_while_loop
1279            if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else
1280            _wrap_computation_in_while_loop_with_stopping_signals)
1281        enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn))
1282      else:
1283        enqueue_ops.append(enqueue_ops_fn())
1284      infeed_queues.append(captured_infeed_queue.get())
1285    else:
1286      for host_id in range(num_hosts):
1287        host_device = tpu_host_placement_fn(host_id=host_id)
1288        with ops.device(host_device):
1289          with ops.name_scope('input_pipeline_task%d' % (host_id)):
1290            if self._ctx.is_input_per_host_with_iterators():
1291              enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
1292                  generate_per_host_v2_enqueue_ops_fn_for_host(
1293                      self._ctx, self._input_fn,
1294                      self._inputs_structure_recorder, host_device, host_id))
1295            else:
1296              enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (
1297                  generate_per_host_enqueue_ops_fn_for_host(
1298                      self._ctx, self._input_fn,
1299                      self._inputs_structure_recorder, self._batch_axis,
1300                      host_device, host_id))
1301
1302            # NOTE(xiejw): We dispatch here based on the return type of the
1303            # user's `input_fn`.
1304            #
1305            # 1. If input_fn returns a Dataset instance, we initialize the
1306            # iterator outside of tf.while_loop, and call iterator.get_next
1307            # inside tf.while_loop.  This should always be safe.
1308            #
1309            # 2. If input_fn returns (features, labels), it is too late to wrap
1310            # them inside tf.while_loop, as resource initialization cannot be
1311            # handled in TF control flow properly. In this case, we will use a
1312            # Python loop to enqueue the data into the TPU system.  This may be
1313            # slow compared to the previous case.
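            #
            # For illustration (a hedged sketch; `file_pattern` and `_parse`
            # are hypothetical), the preferred Dataset-returning `input_fn`
            # looks roughly like:
            #
            #   def input_fn(params):
            #     ds = tf.data.TFRecordDataset(tf.gfile.Glob(file_pattern))
            #     ds = ds.map(_parse)
            #     ds = ds.batch(params['batch_size'], drop_remainder=True)
            #     return ds
            #
            # whereas returning raw (features, labels) tensors forces the
            # slower Python-loop enqueue path handled below.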
1314            if dataset_initializer:
1315              all_dataset_initializers.append(dataset_initializer)
1316              run_infeed_loop_on_coordinator = False
1317              wrap_fn = (
1318                  _wrap_computation_in_while_loop
1319                  if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else
1320                  _wrap_computation_in_while_loop_with_stopping_signals)
1321              enqueue_ops.append(
1322                  wrap_fn(device=host_device, op_fn=enqueue_ops_fn))
1323            else:
1324              enqueue_ops.append(enqueue_ops_fn())
1325            infeed_queues.append(captured_infeed_queue.get())
1326    # infeed_queue is used to generate dequeue ops. The only thing it uses
1327    # for dequeue is the dtypes and shapes, so any one of the queues can be
1328    # used. Here, grab the first one.
1329    self._infeed_queue = infeed_queues[0]
1330    return enqueue_ops, [
1331        util_lib.MultiHostDatasetInitializerHook(all_dataset_initializers)
1332    ], run_infeed_loop_on_coordinator
1333
1334  def _validate_input_pipeline(self):
1335    """Validates the input pipeline.
1336
1337    Performs some sanity checks and logs user-friendly information. Ideally we
1338    would error out to give users a better error message, but if
1339    _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break
1340    user code, so we only log a warning.
1341
1342    Raises:
1343      RuntimeError: If the validation failed.
1344    """
1345    if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS):
1346      err_msg = ('Input pipeline contains one or more QueueRunners. '
1347                 'It could be slow and not scalable. Please consider '
1348                 'converting your input pipeline to use `tf.data` instead (see '
1349                 'https://www.tensorflow.org/guide/datasets for '
1350                 'instructions).')
1351      if _WRAP_INPUT_FN_INTO_WHILE_LOOP:
1352        raise RuntimeError(err_msg)
1353      else:
1354        logging.warn(err_msg)
1355
1356
1357def call_computation(computation,
1358                     experimental_exported_model_uses_all_cores=True):
1359  """Calls `computation` for TPU inference.
1360
1361  If `experimental_exported_model_uses_all_cores` is `True`, this function
1362  will round-robin the computation among all TPU cores visible to the host;
1363  otherwise, it will use a single core.
1366
1367  Args:
1368    computation: A Python function that takes no inputs and builds computation
1369      graph. If `computation` returns m outputs, this function will return a
1370      list of m Tensors.
1371    experimental_exported_model_uses_all_cores: Whether to round-robin among all
1372      cores visible to the host, or to use a single core.
1373
1374  Returns:
1375    A list of output tensors.
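
  Example (a minimal sketch; `build_inference_graph` is a hypothetical user
  function that builds a graph and returns its output tensors):

  ```
  def build_inference_graph():
    images = tf.placeholder(tf.float32, [8, 224, 224, 3])
    logits = tf.layers.dense(tf.reshape(images, [8, -1]), 10)
    return [logits]

  outputs = call_computation(
      build_inference_graph,
      experimental_exported_model_uses_all_cores=False)
  ```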
1376  """
1377  if experimental_exported_model_uses_all_cores:
1378    # Using `TPUPartitionedCall` makes it possible to target a different
1379    # TPU core with every `Session.run()` call. Note that the entire inference
1380    # graph executes on a single core, and that invocations of this graph
1381    # will round-robin among the cores attached to a host.
1382    @function.Defun(capture_resource_var_by_value=False)
1383    def tpu_subgraph():
1384      return computation()
1385
1386    return tpu_functional.TPUPartitionedCall(
1387        args=tpu_subgraph.captured_inputs,
1388        device_ordinal=tpu_ops.tpu_ordinal_selector(),
1389        Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg],
1390        f=tpu_subgraph)
1391  else:
1392    return computation()
1393
1394
1395class _ModelFnWrapper(object):
1396  """A `model_fn` wrapper.
1397
1398  This makes calling model_fn on CPU and TPU easier and more consistent, and
1399  performs the checks and mutations required by TPU training and evaluation.
1400
1401  In addition, this wrapper manages converting the `model_fn` to a single TPU
1402  train and eval step.
1403  """
1404
1405  def __init__(self, model_fn, config, params, ctx):
1406    self._model_fn = model_fn
1407    self._config = config
1408    self._params = params
1409    self._ctx = ctx
1410
1411  def call_without_tpu(self, features, labels, is_export_mode):
1412    return self._call_model_fn(features, labels, is_export_mode=is_export_mode)
1413
1414  def _add_embedding_features(self, features, hook_dummy_table_variables):
1415    """Add embedding features, optionally add hook to intercept gradient."""
1416    if self._ctx.embedding_config:
1417      tpu_embedding_ = self._ctx.embedding_config.tpu_embedding
1418      embedding_activations = tpu_embedding_.get_activations()
1419      if hook_dummy_table_variables:
1420        new_embedding_activations = (
1421            tpu_embedding_gradient.hook_dummy_table_variables_to_activations(
1422                tpu_embedding_, embedding_activations,
1423                self._ctx.embedding_config.dummy_table_variables))
1424        features.update(new_embedding_activations)
1425      else:
1426        features.update(embedding_activations)
1427
1428  def convert_to_single_tpu_train_step(self, dequeue_fn):
1429    """Converts the user-provided `model_fn` into a single train step on TPU.
1430
1431    The user-provided `model_fn` takes an input tuple
1432    (features, labels) and produces an EstimatorSpec with train_op and loss for
1433    train `mode`. This usually represents a single train computation on CPU.
1434
1435    For TPU training, a train (computation) step is first wrapped in a
1436    tf.while_loop control flow to repeat many times and then replicated to
1437    all TPU shards. In addition, the input should be taken from the TPU infeed
1438    rather than from the input pipeline (input_fn) directly. To fit this
1439    loop-and-replicate pattern, the original train computation is reformed into
1440    the returned `train_step`.
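
    For reference, the returned `train_step` is typically driven by a TPU
    training loop roughly like the following (a simplified, hedged sketch;
    `iterations`, `num_shards`, and `initial_loss` are placeholders):

    ```
    def train_loop():
      return training_loop.repeat(iterations, train_step, [initial_loss])

    (loss,) = tpu.shard(train_loop, num_shards=num_shards)
    ```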
1441
1442    Args:
1443      dequeue_fn: The function to retrieve inputs, features and labels, from TPU
1444        infeed dequeue channel.
1445
1446    Returns:
1447      A tuple of train_fn, host_call, captured scaffold_fn, and captured
1448      training hooks. The train_fn represents the train step for TPU.
1449    """
1450
1451    host_call = _OutfeedHostCall(self._ctx)
1452    captured_scaffold_fn = _CapturedObject()
1453    captured_training_hooks = _CapturedObject()
1454
1455    def train_step(loss):
1456      """Training step function for use inside a while loop."""
1457      del loss  # unused; required in function signature.
1458      inputs = dequeue_fn()
1459      features, labels = inputs.features_and_labels()
1460      self._add_embedding_features(features, True)
1461
1462      estimator_spec = self._verify_estimator_spec(
1463          self._call_model_fn(features, labels))
1464      loss, train_op = estimator_spec.loss, estimator_spec.train_op
1465
1466      if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
1467        captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
1468      else:
1469        captured_scaffold_fn.capture(None)
1470
1471      captured_training_hooks.capture(estimator_spec.training_hooks)
1472
1473      if self._ctx.embedding_config is None:
1474        apply_sparse_grads = []
1475      else:
1476        tpu_embedding_ = self._ctx.embedding_config.tpu_embedding
1477        gradients = (
1478            tpu_embedding_gradient.get_gradients_through_dummy_table_variables(
1479                tpu_embedding_)
1480        )
1481        apply_sparse_grads = [
1482            tpu_embedding_.generate_send_gradients_op(gradients)
1483        ]
1484
1485      # We must run train_op to update the variables prior to running the
1486      # outfeed.
1487      with ops.control_dependencies([train_op] + apply_sparse_grads):
1488        host_call_outfeed_ops = []
1489        if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)  # pylint: disable=protected-access
1490            and estimator_spec.host_call is not None):
1491          host_call.record({'host_call': estimator_spec.host_call})
1492          host_call_outfeed_ops = host_call.create_enqueue_op()
1493        with ops.control_dependencies(host_call_outfeed_ops):
1494          return array_ops.identity(loss)
1495
1496    return (train_step, host_call, captured_scaffold_fn,
1497            captured_training_hooks)
1498
1499  def convert_to_single_tpu_eval_step(self, dequeue_fn):
1500    """Converts the user-provided `model_fn` into a single eval step on TPU.
1501
1502    Similar to training, the user-provided `model_fn` takes an input tuple
1503    (features, labels) and produces a TPUEstimatorSpec with eval_metrics for
1504    eval `mode`. This usually represents a single evaluation computation on CPU.
1505
1506    For TPU evaluation, an eval (computation) step is first wrapped in a
1507    tf.while_loop control flow to repeat many times and then replicated to
1508    all TPU shards. In addition, the input and output are slightly different.
1509    The input, features and labels, should be taken from the TPU infeed rather
1510    than from the input pipeline (input_fn) directly. The output is managed in
1511    two stages. First, the model outputs, i.e. the result of the evaluation
1512    computation (usually the model logits), are transferred from the TPU system
1513    to the CPU. Then, all model outputs are concatenated on the CPU and sent to
1514    the metric_fn for metrics computation. To fit this evaluation pattern, the
1515    original eval computation is reformed into the returned `eval_step`.
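
    For example (a hedged sketch): with `eval_metrics=(metric_fn, [labels,
    logits])`, each shard outfeeds its per-batch `labels` and `logits`, the
    host concatenates them along axis 0, and `metric_fn` is called once on
    the concatenated tensors.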
1516
1517    Args:
1518      dequeue_fn: The function to retrieve inputs, features and labels, from TPU
1519        infeed dequeue channel.
1520
1521    Returns:
1522      A tuple of eval_fn, host_calls, captured scaffold_fn, and captured eval
1523      hooks. The eval_fn represents the eval step for TPU.
1524    """
1525    host_calls = _OutfeedHostCall(self._ctx)
1526    captured_scaffold_fn = _CapturedObject()
1527    captured_eval_hooks = _CapturedObject()
1528
1529    def eval_step(total_loss):
1530      """Evaluation step function for use inside a while loop."""
1531      inputs = dequeue_fn()
1532      features, labels = inputs.features_and_labels()
1533      self._add_embedding_features(features, False)
1534
1535      tpu_estimator_spec = self._call_model_fn(features, labels)
1536      if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
1537        raise RuntimeError(
1538            'estimator_spec used by TPU evaluation must have type '
1539            '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
1540
1541      loss = tpu_estimator_spec.loss
1542      captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
1543      captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks)
1544
1545      to_record = {}
1546      if tpu_estimator_spec.eval_metrics:
1547        to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics
1548      if tpu_estimator_spec.host_call is not None:
1549        # We assume that evaluate won't update global step, so we don't wrap
1550        # this host_call.
1551        to_record['host_call'] = tpu_estimator_spec.host_call
1552      host_calls.record(to_record)
1553
1554      with ops.control_dependencies(host_calls.create_enqueue_op()):
1555        return math_ops.add(total_loss, loss)
1556
1557    return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks
1558
1559  def convert_to_single_tpu_predict_step(self, dequeue_fn):
1560    """Converts the user-provided `model_fn` into a single predict step on TPU.
1561
1562    Args:
1563      dequeue_fn: The function to retrieve inputs, features and labels, from TPU
1564        infeed dequeue channel.
1565
1566    Returns:
1567      A tuple of predict_fn, host_calls, captured scaffold_fn, and captured
1568      prediction hooks. The predict_fn represents the predict step for TPU.
1569    """
1570    host_calls = _OutfeedHostCall(self._ctx)
1571    captured_scaffold_fn = _CapturedObject()
1572    captured_predict_hooks = _CapturedObject()
1573
1574    def predict_step(unused_scalar_stopping_signal):
1575      """Prediction step function for use inside a while loop."""
1576      inputs = dequeue_fn()
1577      features, labels = inputs.features_and_labels()
1578      stopping_signals = inputs.signals()
1579
1580      assert stopping_signals is not None, (
1581          'Internal Error: `signals` is missing.')
1582
1583      tpu_estimator_spec = self._call_model_fn(
1584          features, labels, is_export_mode=False)
1585      if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
1586        raise RuntimeError(
1587            'estimator_spec used by TPU prediction must have type '
1588            '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))
1589
1590      self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions)
1591
1592      captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)
1593      captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks)
1594      to_record = {}
1595      identity_fn = lambda **kwargs: kwargs
1596      to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions]
1597      to_record['signals'] = [identity_fn, stopping_signals]
1598      if tpu_estimator_spec.host_call is not None:
1599        to_record['host_call'] = tpu_estimator_spec.host_call
1600      host_calls.record(to_record)
1601
1602      with ops.control_dependencies(host_calls.create_enqueue_op()):
1603        return _StopSignals.as_scalar_stopping_signal(stopping_signals)
1604
1605    return (predict_step, host_calls, captured_scaffold_fn,
1606            captured_predict_hooks)
1607
1608  def _verify_tpu_spec_predictions(self, predictions):
1609    """Validates TPUEstimatorSpec.predictions dict."""
1610    # TODO(xiejw): Add validation for the prediction dictionary.
1611    # TODO(xiejw): Add support for a single tensor as predictions.
1612    if not isinstance(predictions, dict):
1613      raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.')
1614
1615    for (key, tensor) in predictions.items():
1616      if tensor.shape.dims[0].value is None:
1617        raise ValueError(
1618            'The tensor with key ({}) in TPUEstimatorSpec.predictions has '
1619            'dynamic shape (should be static). Tensor: {}'.format(key, tensor))
1620    return predictions
1621
1622  def _validate_model_features_and_labels(self, features, labels,
1623                                          is_export_mode):
1624    """Validates that the features and labels for the model function are valid.
1625
1626    A valid features/labels object is one with:
1627    - Type: A tensor or any nested structure of tensors supported by TF nest,
1628        namely nested dictionary, tuple, namedtuple, or sequence of tensors.
1629    - Static shape if is_export_mode is False.
1630
1631    Args:
1632      features: the features that would be input to the model function.
1633      labels: the labels that would be input to the model function.
1634      is_export_mode: boolean value specifying if in export mode.
1635
1636    Raises:
1637      TypeError: If features/labels are not of the correct type.
1638      ValueError: If features/labels have dynamic shape.
1639    """
1640
1641    def validate(obj, obj_name):
1642      """Helper validate function."""
1643      if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode):
1644        return
1645      if isinstance(obj, ops.Tensor):
1646        if not obj.get_shape().is_fully_defined():
1647          raise ValueError(
1648              'The {} to the model returned by input_fn must have static shape.'
1649              ' Tensor: {}'.format(obj_name, obj))
1650      else:
1651        for tensor in data_nest.flatten(obj):
1652          if not tensor.get_shape().is_fully_defined():
1653            raise ValueError(
1654                ('The {} to the model returned by input_fn must have static '
1655                 'shape. Tensor: {}').format(obj_name, tensor))
1656
1657    validate(features, 'features')
1658    if labels is not None:
1659      validate(labels, 'labels')
1660
1661  def _call_model_fn(self, features, labels, is_export_mode=False):
1662    """Calls the model_fn with required parameters."""
1663    self._validate_model_features_and_labels(features, labels, is_export_mode)
1664    model_fn_args = function_utils.fn_args(self._model_fn)
1665    kwargs = {}
1666
1667    # Makes deep copies of `config` and `params` in case the user mutates them.
1668    config = copy.deepcopy(self._config)
1669    params = copy.deepcopy(self._params)
1670
1671    if 'labels' in model_fn_args:
1672      kwargs['labels'] = labels
1673    elif labels is not None:
1674      raise ValueError(
1675          'model_fn does not take labels, but input_fn returns labels.')
1676    if 'mode' in model_fn_args:
1677      kwargs['mode'] = self._ctx.mode
1678    if 'config' in model_fn_args:
1679      kwargs['config'] = config
1680    if 'params' in model_fn_args:
1681      kwargs['params'] = params
1682
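    # For illustration (a hedged sketch), a conforming signature is:
    #   def model_fn(features, labels, mode, params): ...
    # where params['batch_size'] will carry the per-shard batch size.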
1683    if 'params' not in model_fn_args:
1684      raise ValueError('model_fn ({}) does not include params argument, '
1685                       'required by TPUEstimator to pass batch size as '
1686                       'params[\'batch_size\']'.format(self._model_fn))
1687
1688    if is_export_mode:
1689      batch_size_for_model_fn = None
1690    else:
1691      batch_size_for_model_fn = self._ctx.batch_size_for_model_fn
1692
1693    if batch_size_for_model_fn is not None:
1694      _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn)
1695
1696    running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode)
1697    # In export mode, params['use_tpu'] has already been set based on mode
1698    # (i.e. True for _REWRITE_FOR_INFERENCE_MODE, False otherwise).
1699    if not is_export_mode:
1700      _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu)
1701
1702    if not running_on_cpu:
1703      user_context = tpu_context.TPUContext(
1704          internal_ctx=self._ctx, call_from_input_fn=False)
1705      _add_item_to_params(params, _CTX_KEY, user_context)
1706
1707    estimator_spec = self._model_fn(features=features, **kwargs)
1708    if (running_on_cpu and
1709        isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)):  # pylint: disable=protected-access
1710      # The estimator_spec will be passed to `Estimator` directly, which expects
1711      # type `EstimatorSpec`.
1712      return estimator_spec.as_estimator_spec()
1713    else:
1714      return estimator_spec
1715
1716  def _verify_estimator_spec(self, estimator_spec):
1717    """Validates the estimator_spec."""
1718    if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access
1719      return estimator_spec
1720
1721    err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'
1722    if estimator_spec.training_chief_hooks:
1723      raise ValueError(
1724          err_msg.format('training_chief_hooks') + ' If you want' +
1725          ' to pass training hooks, please pass via training_hooks.')
1726
1727    if estimator_spec.scaffold:
1728      logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. '
1729                      'Please use TPUEstimatorSpec.')
1730    return estimator_spec
1731
1732
1733class _OutfeedHostCall(object):
1734  """Support for `eval_metrics` and `host_call` in TPUEstimatorSpec."""
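
  # For orientation (a hedged sketch; `my_host_fn` is hypothetical), each
  # recorded host call is a (fn, tensors) pair, for example:
  #
  #   host_call = (my_host_fn, {'loss': loss, 'learning_rate': lr})
  #
  # where `tensors` is a list/tuple or a dict of Tensors produced on the TPU
  # and consumed by `fn` on the host.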
1735
1736  def __init__(self, ctx):
1737    self._ctx = ctx
1738    self._names = []
1739    # All of these are dictionaries of lists keyed on the name.
1740    self._host_fns = {}
1741    self._tensor_keys = collections.defaultdict(list)
1742    self._tensors = collections.defaultdict(list)
1743    self._tensor_dtypes = collections.defaultdict(list)
1744    self._tensor_shapes = collections.defaultdict(list)
1745
1746  @staticmethod
1747  def validate(host_calls):
1748    """Validates the `eval_metrics` and `host_call` in `TPUEstimatorSpec`."""
1749
1750    for name, host_call in host_calls.items():
1751      if not isinstance(host_call, (tuple, list)):
1752        raise ValueError('{} should be tuple or list'.format(name))
1753      if len(host_call) != 2:
1754        raise ValueError('{} should have two elements.'.format(name))
1755      if not callable(host_call[0]):
1756        raise TypeError('{}[0] should be callable.'.format(name))
1757      if not isinstance(host_call[1], (tuple, list, dict)):
1758        raise ValueError('{}[1] should be tuple or list, or dict.'.format(name))
1759
1760      if isinstance(host_call[1], (tuple, list)):
1761        fullargspec = tf_inspect.getfullargspec(host_call[0])
1762        fn_args = function_utils.fn_args(host_call[0])
1763        # wrapped_hostcall_with_global_step uses varargs, so we allow that.
1764        if fullargspec.varargs is None and len(host_call[1]) != len(fn_args):
1765          raise RuntimeError(
1766              'In TPUEstimatorSpec.{}, length of tensors {} does not match '
1767              'method args of the function, which takes {}.'.format(
1768                  name, len(host_call[1]), len(fn_args)))
1769
1770  @staticmethod
1771  def create_cpu_hostcall(host_calls):
1772    """Runs the host_call on CPU instead of TPU when use_tpu=False."""
1773
1774    _OutfeedHostCall.validate(host_calls)
1775    ret = {}
1776    for name, host_call in host_calls.items():
1777      host_fn, tensors = host_call
1778      if isinstance(tensors, (tuple, list)):
1779        ret[name] = host_fn(*tensors)
1780      else:
1781        # Must be dict.
1782        try:
1783          ret[name] = host_fn(**tensors)
1784        except TypeError as e:
1785          logging.warning(
1786              'Exception while calling %s: %s. It is likely the tensors '
1787              '(%s[1]) do not match the '
1788              'function\'s arguments', name, e, name)
1789          raise
1790    return ret
1791
1792  def record(self, host_calls):
1793    """Records the host_call structure."""
1794
1795    for name, host_call in host_calls.items():
1796      host_fn, tensor_list_or_dict = host_call
1797      self._names.append(name)
1798      self._host_fns[name] = host_fn
1799
1800      if isinstance(tensor_list_or_dict, dict):
1801        for (key, tensor) in six.iteritems(tensor_list_or_dict):
1802          self._tensor_keys[name].append(key)
1803          self._tensors[name].append(tensor)
1804          self._tensor_dtypes[name].append(tensor.dtype)
1805          self._tensor_shapes[name].append(tensor.shape)
1806      else:
1807        # List or tuple.
1808        self._tensor_keys[name] = None
1809        for tensor in tensor_list_or_dict:
1810          self._tensors[name].append(tensor)
1811          self._tensor_dtypes[name].append(tensor.dtype)
1812          self._tensor_shapes[name].append(tensor.shape)
1813
1814  def create_enqueue_op(self):
1815    """Create the op to enqueue the recorded host_calls.
1816
1817    Returns:
1818      A list of enqueue ops, which is empty if there are no host calls.
1819    """
1820    if not self._names:
1821      return []
1822
1823    tensors = []
1824    # TODO(jhseu): Consider deduping tensors.
1825    for name in self._names:
1826      tensors.extend(self._tensors[name])
1827
1828    with ops.device(tpu.core(0)):
1829      return [tpu_ops.outfeed_enqueue_tuple(tensors)]
1830
1831  def create_tpu_hostcall(self):
1832    """Sends the tensors through outfeed and runs the host_fn on CPU.
1833
1834    The tensors are concatenated along dimension 0 to form a global tensor
1835    across all shards. The concatenated tensors are passed to the host_fn,
1836    which is executed on the first host.
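
    For example (a hedged illustration): if each of 8 replicas outfeeds a
    [16, 10] logits tensor, the host_fn receives a single [128, 10] tensor,
    i.e. the per-replica tensors concatenated along axis 0.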
1837
1838    Returns:
1839      A dictionary mapping name to the return type of the host_call by that
1840      name.
1841
1842    Raises:
1843      RuntimeError: If outfeed tensor is scalar.
1844    """
1845    if not self._names:
1846      return {}
1847
1848    ret = {}
1849    # For each i, dequeue_ops[i] is a list containing the tensors from all
1850    # shards. This list is concatenated later.
1851    dequeue_ops = []
1852    tensor_dtypes = []
1853    tensor_shapes = []
1854    for name in self._names:
1855      for _ in self._tensors[name]:
1856        dequeue_ops.append([])
1857      for dtype in self._tensor_dtypes[name]:
1858        tensor_dtypes.append(dtype)
1859      for shape in self._tensor_shapes[name]:
1860        tensor_shapes.append(shape)
1861
1862    # Outfeed ops execute on each replica's first logical core. Note: we must
1863    # constrain it such that we have at most one outfeed dequeue and enqueue
1864    # per replica.
1865    for i in xrange(self._ctx.num_replicas):
1866      host_device, ordinal_id = self._ctx.device_for_replica(i)
1867      with ops.device(host_device):
1868        outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(
1869            dtypes=tensor_dtypes,
1870            shapes=tensor_shapes,
1871            device_ordinal=ordinal_id)
1872        for j, item in enumerate(outfeed_tensors):
1873          dequeue_ops[j].append(item)
1874
1875    # Deconstruct dequeue ops.
1876    flat_dequeue_ops = []
1877    for l in dequeue_ops:
1878      flat_dequeue_ops.extend(l)
1879
1880    dequeue_ops_by_name = {}
1881    pos = 0
1882    for name in self._names:
1883      dequeue_ops_by_name[name] = dequeue_ops[pos:pos +
1884                                              len(self._tensors[name])]
1885      pos += len(self._tensors[name])
1886
1887    def _call_host_fn(fn, *args, **kw):
1888      context = CatchInvalidHostcallFunctions()
1889      context.Enter()
1890      result = fn(*args, **kw)
1891      context.Exit()
1892      context.ExitResult(result)
1893      return result
1894
1895    # It is assumed evaluation always happens on a single-host TPU system. So,
1896    # place all ops on the TPU host if possible.
1897    #
1898    # TODO(jhseu): Evaluate whether this is right for summaries.
1899    with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)):
1900      for name in self._names:
1901        dequeue_ops = dequeue_ops_by_name[name]
1902        for i, item in enumerate(dequeue_ops):
1903          if dequeue_ops[i][0].shape.ndims == 0:
1904            raise RuntimeError(
1905                'All tensors outfed from TPU should preserve batch size '
1906                'dimension, but got scalar {}'.format(dequeue_ops[i][0]))
1907          # TODO(xiejw): Make the specification of the outfeed combination
1908          # function more explicit and well-documented.  We may want to give the
1909          # user the option of concatenating along any axis.
1910          if (self._ctx.config.tpu_config.per_host_input_for_training is
1911              tpu_config.InputPipelineConfig.BROADCAST):
1912            # If the infeed is in BROADCAST mode (each core receiving the same
1913            # input), then we assume that the cores also produce identical
1914            # copies of the same output, and we simply take the output from
1915            # the first core.  This mode is used by Mesh-TensorFlow.
1916            with ops.control_dependencies(dequeue_ops[i]):
1917              dequeue_ops[i] = array_ops.identity(dequeue_ops[i][0])
1918          else:
1919            # Assume that the input has been batch-split and that axis 0 of the
1920            # output tensors represents the batch size.  Concatenate along
1921            # the axis 0 to re-combine the batch.
1922            dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0)
1923
1924        if self._tensor_keys[name] is not None:
1925          # The user-provided eval_metrics[1] is a dict.
1926          dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops))
1927          try:
1928            ret[name] = _call_host_fn(self._host_fns[name], **dequeue_ops)
1929          except TypeError as e:
1930            logging.warning(
1931                'Exception while calling %s: %s. It is likely the tensors '
1932                '(%s[1]) do not match the '
1933                'function\'s arguments', name, e, name)
1934            raise
1935        else:
1936          ret[name] = _call_host_fn(self._host_fns[name], *dequeue_ops)
1937
1938    # Force all dequeue operations to be run if not consumed by the host calls.
1939    ret['__force_dequeue'] = control_flow_ops.group(*flat_dequeue_ops)
1940    return ret
1941
1942
1943class _OutfeedHostCallHook(session_run_hook.SessionRunHook):
1944  """Hook to run host calls when use_tpu=False."""
1945
1946  def __init__(self, tensors):
1947    self._tensors = tensors
1948
1949  def begin(self):
1950    # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than
1951    # create a separate hook to guarantee execution order, because summaries
1952    # need to be initialized before the outfeed thread starts.
1953    # TODO(jhseu): Make a wrapper hook instead?
1954    self._init_ops = contrib_summary.summary_writer_initializer_op()
1955    # Get all the writer resources from the initializer, so we know what to
1956    # flush.
1957    self._finalize_ops = []
1958    for op in self._init_ops:
1959      self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0]))
1960
1961  def after_create_session(self, session, coord):
1962    session.run(self._init_ops)
1963
1964  def before_run(self, run_context):
1965    return basic_session_run_hooks.SessionRunArgs(self._tensors)
1966
1967  def end(self, session):
1968    session.run(self._finalize_ops)
1969
1970
1971class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook):
1972  """Calculate and report global_step/sec and examples/sec during runtime."""
1973
1974  def __init__(self,
1975               batch_size,
1976               every_n_steps=100,
1977               every_n_secs=None,
1978               output_dir=None,
1979               summary_writer=None):
1980    self._batch_size = batch_size
1981    super(ExamplesPerSecondHook, self).__init__(
1982        every_n_steps=every_n_steps,
1983        every_n_secs=every_n_secs,
1984        output_dir=output_dir,
1985        summary_writer=summary_writer)
1986
1987  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):
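    # examples/sec is computed as batch_size * global_step/sec; this assumes
    # the `batch_size` passed to the constructor is the global batch size.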
1988    global_step_per_sec = elapsed_steps / elapsed_time
1989    examples_per_sec = self._batch_size * global_step_per_sec
1990    if self._summary_writer is not None:
1991      global_step_summary = Summary(value=[
1992          Summary.Value(tag='global_step/sec', simple_value=global_step_per_sec)
1993      ])
1994      example_summary = Summary(value=[
1995          Summary.Value(tag='examples/sec', simple_value=examples_per_sec)
1996      ])
1997      self._summary_writer.add_summary(global_step_summary, global_step)
1998      self._summary_writer.add_summary(example_summary, global_step)
1999    logging.info('global_step/sec: %g', global_step_per_sec)
2000    logging.info('examples/sec: %g', examples_per_sec)
2001
2002
2003class InstallSignalHandlerHook(session_run_hook.SessionRunHook):
2004  """Change SIGINT (CTRL^C) handler to force quit the process.
2005
2006  The default behavior often results in hanging processes.
2007  The original handler is restored after training/evaluation.
2008  """
2009
2010  def __init__(self):
2011    self._signal_fn = signal.getsignal(signal.SIGINT)
2012
2013  def before_run(self, run_context):
2014    signal.signal(signal.SIGINT, signal.SIG_DFL)
2015
2016  def end(self, session):
2017    signal.signal(signal.SIGINT, self._signal_fn)
2018
2019
2020class TPUEstimator(estimator_lib.Estimator):
2021  """Estimator with TPU support.
2022
2023  TPUEstimator also supports training on CPU and GPU. You don't need to define
2024  a separate `tf.estimator.Estimator`.
2025
2026  TPUEstimator handles many of the details of running on TPU devices, such as
2027  replicating inputs and models for each core, and returning to host
2028  periodically to run hooks.
2029
2030  TPUEstimator transforms a global batch size in params to a per-shard batch
2031  size when calling the `input_fn` and `model_fn`. Users should specify the
2032  global batch size in the constructor, and then get the batch size for each shard
2033  in `input_fn` and `model_fn` by `params['batch_size']`.
2034
2035  - For training, `model_fn` gets per-core batch size; `input_fn` may get
2036    per-core or per-host batch size depending on `per_host_input_for_training`
2037    in `TPUConfig` (See docstring for TPUConfig for details).
2038
2039  - For evaluation and prediction, `model_fn` gets per-core batch size and
2040    `input_fn` gets per-host batch size.
2041
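  For example (a sketch): with `train_batch_size=128` on a single host with
  8 TPU cores, `model_fn` sees `params['batch_size'] == 16`, while `input_fn`
  sees 16 (per-core input) or 128 (per-host input), depending on
  `per_host_input_for_training` in `TPUConfig`.
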
2042  Evaluation
2043  ==========
2044
2045  `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics`
2046  for TPU evaluation. If eval_on_tpu is False, the evaluation will execute on
2047  CPU or GPU; in this case the following discussion on TPU evaluation does not
2048  apply.
2049
2050  `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where
2051  `tensors` could be a list of any nested structure of `Tensor`s (See
2052  `TPUEstimatorSpec` for details).  `metric_fn` takes the `tensors` and returns
2053  a dict from metric string name to the result of calling a metric function,
2054  namely a `(metric_tensor, update_op)` tuple.
2055
2056  One can set `use_tpu` to `False` for testing. All training, evaluation, and
2057  predict will be executed on CPU. `input_fn` and `model_fn` will receive
2058  `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`.
2059
2060  Current limitations:
2061  --------------------
2062
2063  1. TPU evaluation only works on a single host (one TPU worker) except
2064     BROADCAST mode.
2065
2066  2. `input_fn` for evaluation should **NOT** raise an end-of-input exception
2067     (`OutOfRangeError` or `StopIteration`). And all evaluation steps and all
2068     batches should have the same size.
2069
2070  Example (MNIST):
2071  ----------------
2072
2073  ```
2074  # The metric Fn which runs on CPU.
2075  def metric_fn(labels, logits):
2076    predictions = tf.argmax(logits, 1)
2077    return {
2078      'accuracy': tf.metrics.accuracy(
2079          labels=labels, predictions=predictions),
2080    }
2081
2082  # Your model Fn which runs on TPU (eval_metrics is a list in this example)
2083  def model_fn(features, labels, mode, config, params):
2084    ...
2085    logits = ...
2086
2087    if mode == tf.estimator.ModeKeys.EVAL:
2088      return tpu_estimator.TPUEstimatorSpec(
2089          mode=mode,
2090          loss=loss,
2091          eval_metrics=(metric_fn, [labels, logits]))
2092
2093  # or specify the eval_metrics tensors as dict.
2094  def model_fn(features, labels, mode, config, params):
2095    ...
2096    final_layer_output = ...
2097
2098    if mode == tf.estimator.ModeKeys.EVAL:
2099      return tpu_estimator.TPUEstimatorSpec(
2100          mode=mode,
2101          loss=loss,
2102          eval_metrics=(metric_fn, {
2103              'labels': labels,
2104              'logits': final_layer_output,
2105          }))
2106  ```
2107
2108  Prediction
2109  ==========
2110
2111  Prediction on TPU is an experimental feature to support large batch inference.
2112  It is not designed for latency-critical systems. In addition, due to some
2113  usability issues, for prediction with a small dataset, CPU `.predict`, i.e.,
2114  creating a new `TPUEstimator` instance with `use_tpu=False`, might be more
2115  convenient.
2116
2117  Note: In contrast to TPU training/evaluation, the `input_fn` for prediction
2118  *should* raise an end-of-input exception (`OutOfRangeError` or
2119  `StopIteration`), which serves as the stopping signal to `TPUEstimator`. To be
2120  precise, the ops created by `input_fn` produce one batch of the data.
2121  The `predict()` API processes one batch at a time. When reaching the end of
2122  the data source, an end-of-input exception should be raised by one of these
2123  operations. The user usually does not need to do this manually. As long as the
2124  dataset is not repeated forever, the `tf.data` API will raise an end-of-input
2125  exception automatically after the last batch has been produced.
2126
2127  Note: Estimator.predict returns a Python generator. Please consume all the
2128  data from the generator so that TPUEstimator can shut down the TPU system
2129  properly for the user.
2130
2131  Current limitations:
2132  --------------------
2133  1. TPU prediction only works on a single host (one TPU worker).
2134
2135  2. `input_fn` must return a `Dataset` instance rather than `features`. In
2136  fact, .train() and .evaluate() also support Dataset as return value.
2137
2138  Example (MNIST):
2139  ----------------
2140  ```
2141  height = 32
2142  width = 32
2143  total_examples = 100
2144
2145  def predict_input_fn(params):
2146    batch_size = params['batch_size']
2147
2148    images = tf.random_uniform(
2149        [total_examples, height, width, 3], minval=-1, maxval=1)
2150
2151    dataset = tf.data.Dataset.from_tensor_slices(images)
2152    dataset = dataset.map(lambda images: {'image': images})
2153
2154    dataset = dataset.batch(batch_size)
2155    return dataset
2156
2157  def model_fn(features, labels, params, mode):
2158    # Generate predictions, called 'output', from features['image'].
2159
2160    if mode == tf.estimator.ModeKeys.PREDICT:
2161      return tf.contrib.tpu.TPUEstimatorSpec(
2162          mode=mode,
2163          predictions={
2164              'predictions': output,
2165              'is_padding': features['is_padding']
2166          })
2167
2168  tpu_est = TPUEstimator(
2169      model_fn=model_fn,
2170      ...,
2171      predict_batch_size=16)
2172
2173  # Fully consume the generator so that TPUEstimator can shutdown the TPU
2174  # system.
2175  for item in tpu_est.predict(input_fn=predict_input_fn):
2176    # Filter out item if the `is_padding` is 1.
2177    # Process the 'predictions'
2178  ```
2179
2180  Exporting
2181  =========
2182
2183  `export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`,
2184  and another with `tag_constants.SERVING` and `tag_constants.TPU`.
2185  At serving time, these tags are used to select the appropriate metagraph to load.
2186
2187  Before running the graph on TPU, the TPU system needs to be initialized. If
2188  TensorFlow Serving model-server is used, this is done automatically. If
2189  not, please call `session.run(tpu.initialize_system())`.
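
  For example, initializing the TPU system manually might look like this
  (a hedged sketch; `tpu_grpc_url` is an assumed address of the TPU worker):

  ```
  with tf.Session(tpu_grpc_url) as sess:
    sess.run(tpu.initialize_system())
    # ... load the SERVING+TPU metagraph and run inference ...
    sess.run(tpu.shutdown_system())
  ```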
2190
2191  `tpu.outside_compilation` can be used to wrap TPU incompatible ops in
2192  `model_fn`.
2193
2194  Example:
2195  ----------------
2196
2197  ```
2198  def model_fn(features, labels, mode, config, params):
2199    ...
2200    logits = ...
2201    export_outputs = {
2202      'logits': export_output_lib.PredictOutput(
2203        {'logits': logits})
2204    }
2205
2206    def host_call(logits):
2207      class_ids = math_ops.argmax(logits)
2208      classes = string_ops.as_string(class_ids)
2209      export_outputs['classes'] = (
2210          export_output_lib.ClassificationOutput(classes=classes))
2211
2212    tpu.outside_compilation(host_call, logits)
2213
2214    ...
2215  ```
2216
2217  """
2218
2219  def __init__(self,
2220               model_fn=None,
2221               model_dir=None,
2222               config=None,
2223               params=None,
2224               use_tpu=True,
2225               train_batch_size=None,
2226               eval_batch_size=None,
2227               predict_batch_size=None,
2228               batch_axis=None,
2229               eval_on_tpu=True,
2230               export_to_tpu=True,
2231               export_to_cpu=True,
2232               warm_start_from=None,
2233               experimental_exported_model_uses_all_cores=False,
2234               experimental_export_device_assignment=False,
2235               experimental_embedding_config_spec=None):
2236    """Constructs an `TPUEstimator` instance.
2237
2238    Args:
2239      model_fn: Model function as required by `Estimator` which returns
2240        EstimatorSpec or TPUEstimatorSpec. `training_hooks`, `evaluation_hooks`,
2241        and `prediction_hooks` must not capture any TPU Tensor inside the
2242        model_fn.
2243      model_dir: Directory to save model parameters, graph, etc. This can
2244        also be used to load checkpoints from the directory into an estimator to
2245        continue training a previously saved model. If `None`, the model_dir in
2246        `config` will be used if set. If both are set, they must be the same. If
2247        both are `None`, a temporary directory will be used.
2248      config: An `tpu_config.RunConfig` configuration object. Cannot be `None`.
2249      params: An optional `dict` of hyper parameters that will be passed into
2250        `input_fn` and `model_fn`.  Keys are names of parameters, values are
2251        basic python types. There are reserved keys for `TPUEstimator`,
2252        including 'batch_size'.
2253      use_tpu: A bool indicating whether TPU support is enabled. Currently:
2254        - TPU training and evaluation respect this bit, but eval_on_tpu can
2255          override execution of eval. See below.
2256        - Predict still happens on CPU.
2256      train_batch_size: An int representing the global training batch size.
2257        TPUEstimator transforms this global batch size to a per-shard batch
2258        size, as params['batch_size'], when calling `input_fn` and `model_fn`.
2259        Cannot be `None` if `use_tpu` is `True`. Must be divisible by total
2260        number of replicas.
2261      eval_batch_size: An int representing evaluation batch size. Must be
2262        divisible by total number of replicas.
2263      predict_batch_size: An int representing the prediction batch size. Must be
2264        divisible by total number of replicas.
2265      batch_axis: A python tuple of int values describing how each tensor
2266        produced by the Estimator `input_fn` should be split across the TPU
2267        compute shards. For example, if your input_fn produced (images, labels)
2268        where the images tensor is in `HWCN` format, your shard dimensions would
2269        be [3, 0], where 3 corresponds to the `N` dimension of your images
2270        Tensor, and 0 corresponds to the dimension along which to split the
2271        labels to match up with the corresponding images. If None is supplied,
2272        and per_host_input_for_training is True, batches will be sharded based
2273        on the major dimension. If tpu_config.per_host_input_for_training is
2274        False or `PER_HOST_V2`, batch_axis is ignored.
2275      eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the
2276        model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`.
2277      export_to_tpu: If True, `export_savedmodel()` exports a metagraph for
2278        serving on TPU. Note that unsupported export modes such as EVAL will be
2279        ignored. For those modes, only a CPU model will be exported.
2280        Currently, export_to_tpu only supports PREDICT.
2281      export_to_cpu: If True, `export_savedmodel()` exports a metagraph for
2282        serving on CPU.
2283      warm_start_from: Optional string filepath to a checkpoint or SavedModel to
2284        warm-start from, or a `tf.estimator.WarmStartSettings` object to fully
2285        configure warm-starting.  If the string filepath is provided instead of
2286        a `WarmStartSettings`, then all variables are warm-started, and it is
2287        assumed that vocabularies and Tensor names are unchanged.
2288      experimental_exported_model_uses_all_cores: Whether to round-robin among
2289        all cores visible to the host which is serving the saved model, or to
2290        use a single core. This is a temporary flag to enable using all TPU
2291        cores for inference with TPUPartitionedCall(). Once outside compilation
2292        is supported in TPUPartitionedCall(), this flag will be enabled by
2293        default.
2294      experimental_export_device_assignment: Whether to include the device
2295        assignment in the exported model. Doing so is useful in case of model
2296        parallel inference but will tie the exported model to the TPU topology
2297        used to export the model.
2298      experimental_embedding_config_spec: Optional EmbeddingConfigSpec instance
2299        to support using TPU embedding. IT IS STILL WORK IN PROGRESS, SO PLEASE
2300        DO NOT USE.
2301
2302    Raises:
2303      ValueError: `params` has reserved keys already.
2304    """
2305    if config is None or not isinstance(config, tpu_config.RunConfig):
2306      raise ValueError(
2307          '`config` must be provided with type `tpu_config.RunConfig`')
2308
2309    if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS):
2310      raise ValueError('{} are reserved keys but existed in params {}.'.format(
2311          _RESERVED_PARAMS_KEYS, params))
2312
2313    if use_tpu:
2314      # Perform some very basic validations. More validations will be found in
2315      # _InternalTPUContext.
2316      if train_batch_size is None:
2317        raise ValueError('`train_batch_size` cannot be `None`')
2318      util_lib.check_positive_integer(train_batch_size, 'train_batch_size')
2319
2320      if (config.tpu_config.per_host_input_for_training is
2321          tpu_config.InputPipelineConfig.PER_SHARD_V1 and
2322          config.tpu_config.num_cores_per_replica):
2323        raise ValueError(
2324            'Model parallelism only supports per host input for training. '
2325            'Please adjust TPUConfig.per_host_input_for_training.')
2326
2327      if eval_batch_size is not None:
2328        util_lib.check_positive_integer(eval_batch_size, 'eval_batch_size')
2329
2330      if predict_batch_size is not None:
2331        util_lib.check_positive_integer(predict_batch_size,
2332                                        'predict_batch_size')
2333
2334    # Verifies the model_fn signature according to Estimator framework.
2335    estimator_lib._verify_model_fn_args(model_fn, params)  # pylint: disable=protected-access
2336    # We cannot store config and params in this constructor as parent
2337    # constructor might change them, such as assigning a temp dir for
2338    # config.model_dir.
2339    model_function = self._augment_model_fn(model_fn, batch_axis)
2340
2341    # Overwrite log_step_count_steps to prevent TensorLoggingHook and
2342    # StepCounterHook from being created in Estimator. TPUEstimator already
2343    # added equivalent hooks in _augment_model_fn above.
2344    self._log_every_n_steps = config.log_step_count_steps
2345    config = config.replace(log_step_count_steps=None)
2346
2347    # Pass non-None params, as the wrapped model_fn expects it.
2348    params = params or {}
2349    super(TPUEstimator, self).__init__(
2350        model_fn=model_function,
2351        model_dir=model_dir,
2352        config=config,
2353        params=params,
2354        warm_start_from=warm_start_from)
2355    self._iterations_per_training_loop = (
2356        self._config.tpu_config.iterations_per_loop)
2357
2358    # All properties passed to _InternalTPUContext are immutable.
2359    # pylint: disable=protected-access
2360    self._ctx = tpu_context._get_tpu_context(
2361        self._config, train_batch_size, eval_batch_size, predict_batch_size,
2362        use_tpu, eval_on_tpu, experimental_embedding_config_spec)
2363
2364    self._export_to_cpu = export_to_cpu
2365    self._export_to_tpu = export_to_tpu
2366    self._experimental_exported_model_uses_all_cores = (
2367        experimental_exported_model_uses_all_cores)
2368    self._experimental_export_device_assignment = (
2369        experimental_export_device_assignment)
2370    if (experimental_exported_model_uses_all_cores and
2371        experimental_export_device_assignment):
2372      raise ValueError('experimental_exported_model_uses_all_cores and '
2373                       'experimental_export_device_assignment are not supported '
2374                       'at the same time.')
2375
2376    self._is_input_fn_invoked = None
2377    self._rendezvous = {}
2378
2379  def _add_meta_graph_for_mode(self,
2380                               builder,
2381                               input_receiver_fn_map,
2382                               checkpoint_path,
2383                               save_variables=True,
2384                               mode=model_fn_lib.ModeKeys.PREDICT,
2385                               export_tags=None,
2386                               check_variables=True):
2387    if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT:
2388      logging.warning('TPUEstimator only handles mode PREDICT for exporting '
2389                      'when `export_to_tpu` is `True`; Mode {} will be ignored '
2390                      'for TPU.'.format(mode))
2391
2392    if not self._export_to_cpu and not self._export_to_tpu:
2393      raise ValueError('One of export_to_cpu and export_to_tpu must be true.')
2394
2395    if self._export_to_cpu:
2396      (super(TPUEstimator, self)._add_meta_graph_for_mode(
2397          builder,
2398          input_receiver_fn_map,
2399          checkpoint_path,
2400          save_variables,
2401          mode=mode,
2402          export_tags=export_tags,
2403          check_variables=check_variables))
2404
2405    if self._export_to_tpu and mode == model_fn_lib.ModeKeys.PREDICT:
2406      input_receiver_fn_map = {
2407          _REWRITE_FOR_INFERENCE_MODE: input_receiver_fn_map[mode]
2408      }
2409      export_tags = [tag_constants.SERVING, tag_constants.TPU]
2410      mode = _REWRITE_FOR_INFERENCE_MODE
2411
2412      # See b/110052256 for why `check_variables` is `False`.
2413      if not self._export_to_cpu:
2414        check_variables = save_variables = True
2415      else:
2416        check_variables = save_variables = False
2417      (super(TPUEstimator, self)._add_meta_graph_for_mode(
2418          builder,
2419          input_receiver_fn_map,
2420          checkpoint_path,
2421          save_variables=save_variables,
2422          mode=mode,
2423          export_tags=export_tags,
2424          check_variables=check_variables))
2425
2426  def _call_model_fn(self, features, labels, mode, config):
2427    if mode == _REWRITE_FOR_INFERENCE_MODE:
2428      return self._call_model_fn_for_inference(features, labels, mode, config)
2429    else:
2430      return super(TPUEstimator, self)._call_model_fn(features, labels, mode,
2431                                                      config)
2432
2433  def _call_model_fn_for_inference(self, features, labels, mode, config):
2434    """Wraps `_call_model_fn` for `export_savedmodel`."""
2435    if mode != _REWRITE_FOR_INFERENCE_MODE:
2436      raise ValueError('mode must be {}; '
2437                       'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode))
2438
2439    computation, capture = self._build_computation_for_inference(
2440        features, labels, mode, config)
2441    tensors = call_computation(
2442        computation,
2443        experimental_exported_model_uses_all_cores=self
2444        ._experimental_exported_model_uses_all_cores)
2445    estimator_spec, export_outputs_dict, predictions_dict, none_indices = (
2446        capture.get())
2447    predictions_list = tensors[:len(predictions_dict)]
2448    export_outputs_list_without_none = tensors[len(predictions_dict):]
2449
2450    # Reinsert `None`s which we've taken out in
2451    # `_build_computation_for_inference()`.
2452    export_outputs_list = []
2453    while none_indices or export_outputs_list_without_none:
2454      if none_indices and none_indices[0] == len(export_outputs_list):
2455        export_outputs_list.append(None)
2456        none_indices.pop(0)
2457      else:
2458        export_outputs_list.append(export_outputs_list_without_none.pop(0))
2459
2460    # Reconstruct `export_outputs` with updated tensors.
2461    new_export_outputs_dict = nest.pack_sequence_as(export_outputs_dict,
2462                                                    export_outputs_list)
2463    export_outputs = estimator_spec.export_outputs
2464    new_export_outputs = collections.OrderedDict(
2465        (k, _clone_export_output_with_tensors(export_outputs[k], v))
2466        for k, v in six.iteritems(new_export_outputs_dict))
2467    # Reconstruct `predictions` with updated tensors.
2468    new_predictions = nest.pack_sequence_as(predictions_dict, predictions_list)
2469    if (len(new_predictions) == 1 and
2470        _KEY_WHEN_PREDICTIONS_IS_A_TENSOR in new_predictions):
2471      new_predictions = new_predictions[_KEY_WHEN_PREDICTIONS_IS_A_TENSOR]
2472
2473    return estimator_spec._replace(
2474        export_outputs=new_export_outputs, predictions=new_predictions)
2475
2476  def _build_computation_for_inference(self, features, labels, mode, config):
2477    capture = _CapturedObject()
2478
2479    def computation():
2480      """Computation to be passed to `TPUPartitionedCall()`."""
2481      tpu_computation, tpu_capture = self._build_tpu_computation_for_inference(
2482          features, labels, mode, config)
2483
2484      if self._experimental_export_device_assignment:
2485        # Export the device assignment as part of the model. This is useful for
2486        # model parallel use cases where the model relies on the mapping between
2487        # logical and physical devices.
2488        with self._ctx.with_mode(mode) as ctx:
2489          device_assignment = ctx.device_assignment
2490      else:
2491        device_assignment = None
2492
2493      if self._experimental_exported_model_uses_all_cores:
2494        tensors_on_cpu = tpu.rewrite(
2495            tpu_computation, device_assignment=device_assignment)
2496        tpu.prune_unconnected_ops_from_xla(ops.get_default_graph())
2497      else:
2498        tensors_on_cpu = tpu.rewrite_for_inference(
2499            tpu_computation, device_assignment=device_assignment)
2500
2501      (estimator_spec, export_outputs_dict, export_outputs_list,
2502       predictions_dict) = (
2503           tpu_capture.get())
2504      predictions_list = tensors_on_cpu[:len(predictions_dict)]
2505      export_outputs_tpu_on_cpu_list = tensors_on_cpu[len(predictions_dict):]
2506
2507      # Reconstruct tensors used in export_outputs, with TPU tensors replaced
2508      # with their CPU counterpart returned from `rewrite_for_inference()`.
2509      # `function.Defun()` does not like `None`s in return values, so we leave
2510      # `None`s out but record their positions for later reconstruction.
2511      export_outputs_list_without_none = []
2512      none_indices = []
2513      for i, t in enumerate(export_outputs_list):
2514        if t is None:
2515          none_indices.append(i)
2516        else:
2517          export_outputs_list_without_none.append(
2518              export_outputs_tpu_on_cpu_list.pop(0))
2519
2520      capture.capture((estimator_spec, export_outputs_dict, predictions_dict,
2521                       none_indices))
2522      return predictions_list + export_outputs_list_without_none
2523
2524    return computation, capture
2525
2526  def _build_tpu_computation_for_inference(self, features, labels, mode,
2527                                           config):
2528    capture = _CapturedObject()
2529
2530    def computation():
2531      """Compute tpu tensors used in export_outputs.
2532
2533      Passed to rewrite_for_inference so that model_fn will be called under
2534      the rewriting contexts. Only tpu tensors are returned, but export_outputs
2535      and scaffold are captured.
2536
2537      Returns:
2538         A list of Tensors used in export_outputs and not marked for
2539         outside_compilation.
2540      """
2541      # We should only call model fn once and it should be inside `computation`
2542      # so that building the graph will happen under `rewrite_for_inference`.
2543      estimator_spec = super(TPUEstimator, self)._call_model_fn(
2544          features, labels, mode, config)
2545
2546      # We pick the TPU tensors out from `export_output` and later return them
2547      # from `computation` for rewriting.
2548      export_outputs_dict = collections.OrderedDict(
2549          (k, _export_output_to_tensors(v))
2550          for k, v in six.iteritems(estimator_spec.export_outputs))
2551      export_outputs_list = nest.flatten(export_outputs_dict)
2552      export_outputs_tpu_list = [
2553          t for t in export_outputs_list if t is not None
2554      ]
2555
2556      if isinstance(estimator_spec.predictions, dict):
2557        predictions_dict = collections.OrderedDict(
2558            (k, v) for k, v in six.iteritems(estimator_spec.predictions))
2559      else:
2560        predictions_dict = {
2561            _KEY_WHEN_PREDICTIONS_IS_A_TENSOR: estimator_spec.predictions
2562        }
2563      predictions_list = nest.flatten(predictions_dict)
2564
2565      # We cannot return everything we want through the return values, so
2566      # capture the rest here for later use.
2567      capture.capture((estimator_spec, export_outputs_dict, export_outputs_list,
2568                       predictions_dict))
2569      return predictions_list + export_outputs_tpu_list
2570
2571    return computation, capture
2572
2573  def _create_global_step(self, graph):
2574    """Creates a global step suitable for TPUs.
2575
2576    Args:
2577      graph: The graph in which to create the global step.
2578
2579    Returns:
2580      A global step `Tensor`.
2581
2582    Raises:
2583      ValueError: if the global step tensor is already defined.
2584    """
2585    return _create_global_step(graph)
2586
2587  def _convert_train_steps_to_hooks(self, steps, max_steps):
2588    with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as ctx:
2589      if ctx.is_running_on_cpu():
2590        return super(TPUEstimator, self)._convert_train_steps_to_hooks(
2591            steps, max_steps)
2592
2593    # On TPU.
2594    if steps is None and max_steps is None:
2595      raise ValueError(
2596          'For TPU training, one of `steps` or `max_steps` must be set. '
2597          'They cannot both be `None`.')
2598
2599    # Estimator.train has an explicit positivity check.
2600    if steps is not None:
2601      util_lib.check_positive_integer(steps, 'Train steps')
2602    if max_steps is not None:
2603      util_lib.check_positive_integer(max_steps, 'Train max_steps')
2604
2605    return [
2606        _TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps)
2607    ]
2608
2609  def _convert_eval_steps_to_hooks(self, steps):
2610    with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx:
2611      if ctx.is_running_on_cpu():
2612        return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps)
2613
2614    if steps is None:
2615      raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.')
2616
2617    util_lib.check_positive_integer(steps, 'Eval steps')
2618
2619    return [
2620        evaluation._StopAfterNEvalsHook(  # pylint: disable=protected-access
2621            num_evals=steps),
2622        _SetEvalIterationsHook(steps)
2623    ]
2624
2625  def _call_input_fn(self, input_fn, mode):
2626    """Calls the input function.
2627
2628    Args:
2629      input_fn: The input function.
2630      mode: ModeKeys
2631
2632    Returns:
2633      In TPU mode, returns an input_fn to be called later in model_fn.
2634      Otherwise, calls the input_fn and returns either features or
2635        (features, labels).
2636
2637    Raises:
2638      ValueError: if input_fn takes invalid arguments or does not have `params`.
2639    """
2640    input_fn_args = function_utils.fn_args(input_fn)
2641    config = self.config  # a deep copy.
2642    kwargs = {}
2643    if 'params' in input_fn_args:
2644      kwargs['params'] = self.params  # a deep copy.
2645    else:
2646      raise ValueError('input_fn ({}) does not include params argument, '
2647                       'required by TPUEstimator to pass batch size as '
2648                       'params["batch_size"]'.format(input_fn))
2649    if 'config' in input_fn_args:
2650      kwargs['config'] = config
2651
2652    if 'mode' in input_fn_args:
2653      kwargs['mode'] = mode
2654
2655    # Records the fact input_fn has been invoked.
2656    self._is_input_fn_invoked = True
2657
2658    with self._ctx.with_mode(mode) as ctx:
2659      # Set the batch size in params first. This helps the user to have the
2660      # same input_fn for use_tpu=True/False.
2661      batch_size_for_input_fn = ctx.batch_size_for_input_fn
2662      if batch_size_for_input_fn is not None:
2663        _add_item_to_params(kwargs['params'], _BATCH_SIZE_KEY,
2664                            batch_size_for_input_fn)
2665
2666      # For export_savedmodel, input_fn is never passed to Estimator. So,
2667      # `is_export_mode` must be False.
2668      if ctx.is_running_on_cpu(is_export_mode=False):
2669        with ops.device('/device:CPU:0'):
2670          return input_fn(**kwargs)
2671
2672      # For TPU computation, input_fn should be invoked in a tf.while_loop for
2673      # performance. While constructing the tf.while_loop, the structure of the
2674      # inputs returned by the `input_fn` needs to be recorded. The structure
2675      # includes whether features or labels is a dict or a single Tensor, the
2676      # dict keys, tensor shapes, and dtypes. The recorded structure is used to
2677      # create the infeed dequeue ops, which must be wrapped and passed as a
2678      # function called inside the TPU computation, since the TPU computation
2679      # itself is also wrapped inside a tf.while_loop. So, we either pass
2680      # input_fn to model_fn or pass dequeue_fn to model_fn. Here, `input_fn`
2681      # is passed directly as `features` in the `model_fn` signature.
2682      def _input_fn(ctx):
2683        _add_item_to_params(kwargs['params'], _CTX_KEY, ctx)
2684        return input_fn(**kwargs)
2685
2686      return _input_fn
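      # Illustrative sketch of a compatible user input_fn (hypothetical user
      # code, not part of this module); it reads the batch size that
      # TPUEstimator injects into params above:
      #
      #   def input_fn(params):
      #     batch_size = params['batch_size']
      #     dataset = tf.data.Dataset.from_tensor_slices((features, labels))
      #     return dataset.batch(batch_size, drop_remainder=True)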
2687
2688  def _validate_features_in_predict_input(self, result):
2689    """Skip the validation.
2690
2691    For TPUEstimator, we do not need to check the result type. `_InputPipeline`
2692    has a stronger check. The parent class's check emits a confusing warning.
2693
2694    Args:
2695      result: `features` returned by input_fn.
2696    """
2697    pass
2698
2699  def train(self,
2700            input_fn,
2701            hooks=None,
2702            steps=None,
2703            max_steps=None,
2704            saving_listeners=None):
2705    rendezvous = error_handling.ErrorRendezvous(num_sources=3)
2706    self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous
2707    try:
2708      return super(TPUEstimator, self).train(
2709          input_fn=input_fn,
2710          hooks=hooks,
2711          steps=steps,
2712          max_steps=max_steps,
2713          saving_listeners=saving_listeners)
2714    except Exception:  # pylint: disable=broad-except
2715      rendezvous.record_error('training_loop', sys.exc_info())
2716    finally:
2717      rendezvous.record_done('training_loop')
2718      rendezvous.raise_errors()
2719
2720  def evaluate(self,
2721               input_fn,
2722               steps=None,
2723               hooks=None,
2724               checkpoint_path=None,
2725               name=None):
2726    rendezvous = error_handling.ErrorRendezvous(num_sources=3)
2727    self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous
2728    try:
2729      return super(TPUEstimator, self).evaluate(
2730          input_fn,
2731          steps=steps,
2732          hooks=hooks,
2733          checkpoint_path=checkpoint_path,
2734          name=name)
2735    except Exception:  # pylint: disable=broad-except
2736      rendezvous.record_error('evaluation_loop', sys.exc_info())
2737    finally:
2738      rendezvous.record_done('evaluation_loop')
2739      rendezvous.raise_errors()
2740
2741  def predict(self,
2742              input_fn,
2743              predict_keys=None,
2744              hooks=None,
2745              checkpoint_path=None,
2746              yield_single_examples=True):
2747    rendezvous = error_handling.ErrorRendezvous(num_sources=3)
2748    self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous
2749    try:
2750      for result in super(TPUEstimator, self).predict(
2751          input_fn=input_fn,
2752          predict_keys=predict_keys,
2753          hooks=hooks,
2754          checkpoint_path=checkpoint_path,
2755          yield_single_examples=yield_single_examples):
2756        yield result
2757    except Exception:  # pylint: disable=broad-except
2758      rendezvous.record_error('prediction_loop', sys.exc_info())
2759    finally:
2760      rendezvous.record_done('prediction_loop')
2761      rendezvous.raise_errors()
2762
2763    rendezvous.record_done('prediction_loop')
2764    rendezvous.raise_errors()
2765
2766  def _augment_model_fn(self, model_fn, batch_axis):
2767    """Returns a new model_fn, which wraps the TPU support."""
2768
2769    def _model_fn(features, labels, mode, config, params):
2770      """A Estimator `model_fn` for TPUEstimator."""
2771
2772      # `input_fn` is called in `train()`, `evaluate()`, and `predict()`,
2773      # but not in `export_savedmodel()`.
2774      if self._is_input_fn_invoked:
2775        is_export_mode = False
2776      else:
2777        is_export_mode = True
2778
2779      # Clear the bit.
2780      self._is_input_fn_invoked = None
2781
2782      if is_export_mode:
2783        if mode == _REWRITE_FOR_INFERENCE_MODE:
2784          _add_item_to_params(params, _USE_TPU_KEY, True)
2785          mode = model_fn_lib.ModeKeys.PREDICT
2786        else:
2787          _add_item_to_params(params, _USE_TPU_KEY, False)
2788
2789      with self._ctx.with_mode(mode) as ctx:
2790        model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)
2791
2792        # examples_hook is added to training_hooks for both CPU and TPU
2793        # execution.
2794        if self._log_every_n_steps is not None:
2795          examples_hook = ExamplesPerSecondHook(
2796              ctx.global_batch_size,
2797              # pylint:disable=g-long-ternary
2798              output_dir=(self.model_dir
2799                          if not config or config.save_summary_steps
2800                          else None),
2801              # pylint:enable=g-long-ternary
2802              every_n_steps=self._log_every_n_steps)
2803
2804        if ctx.is_running_on_cpu(is_export_mode=is_export_mode):
2805          logging.info('Running %s on CPU', mode)
2806          estimator_spec = model_fn_wrapper.call_without_tpu(
2807              features, labels, is_export_mode=is_export_mode)
2808          if self._log_every_n_steps is not None:
2809            estimator_spec = estimator_spec._replace(
2810                training_hooks=estimator_spec.training_hooks + (examples_hook,))
2811          return estimator_spec
2812
2813        assert labels is None, '`labels` passed to `model_fn` must be `None`.'
2814        # TPUEstimator._call_input_fn passes `input_fn` as features to here.
2815        assert callable(features), '`input_fn` is not callable.'
2816        input_fn = features
2817
2818        tpu_init_ops = []
2819        if ctx.embedding_config and mode == model_fn_lib.ModeKeys.TRAIN:
2820          dummy_table_variables, dummy_table_variables_init = (
2821              tpu_embedding_gradient.create_dummy_table_variables(
2822                  ctx.embedding_config.tpu_embedding))
2823          ctx.embedding_config.dummy_table_variables = dummy_table_variables
2824          tpu_init_ops.append(dummy_table_variables_init)
2825
2826        input_holders = _InputPipeline(input_fn, batch_axis, ctx)
2827        enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = (
2828            input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())
2829
2830        graph = ops.get_default_graph()
2831        for enqueue_op in enqueue_ops:
2832          if isinstance(enqueue_op, list):
2833            graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op)
2834          else:
2835            graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op)
2836
2837        if mode == model_fn_lib.ModeKeys.TRAIN:
2838          compile_op, loss, host_call, scaffold, training_hooks = (
2839              _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
2840          if ctx.embedding_config:
2841            g = ops.get_default_graph()
2842            table_to_config_dict = (
2843                ctx.embedding_config.tpu_embedding.table_to_config_dict)
2844            optimization_parameters = (
2845                ctx.embedding_config.tpu_embedding.optimization_parameters)
2846            embedding_variable_name_by_table, slot_variable_names_by_table = (
2847                _tpu_estimator_embedding.get_full_variable_names(
2848                    g, table_to_config_dict, optimization_parameters
2849                )
2850            )
2851            embedding_variables_and_ops = (
2852                ctx.embedding_config.tpu_embedding.create_variables_and_ops(
2853                    embedding_variable_name_by_table,
2854                    slot_variable_names_by_table
2855                ))
2856            tpu_init_ops.extend(embedding_variables_and_ops.load_ops())
2857
2858          host_ops = host_call.create_tpu_hostcall()
2859          if host_ops is None:
2860            host_ops = []
2861
2862          shutdown_hooks = []
2863          shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE',
2864                                         'shutdown_worker')
2865          if shutdown_mode:
2866            if shutdown_mode == 'shutdown_worker':
2867              finalizer_hooks = [
2868                  session_support.ShutdownLameWorkers(timeout_ms=60 * 1000),
2869              ]
2870            elif shutdown_mode == 'shutdown_computation':
2871              finalizer_hooks = [
2872                  session_support.RestartComputation(timeout_ms=60 * 1000),
2873              ]
2874            else:
2875              raise ValueError(
2876                  'Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' % shutdown_mode)
2877
2878            shutdown_hooks.append(
2879                session_support.GracefulShutdownHook(
2880                    checkpoint_prefix=self.model_dir + '/model.ckpt',
2881                    on_shutdown_hooks=finalizer_hooks))
2882
2883          with ops.control_dependencies([loss]):
2884            global_step = array_ops.identity(training.get_global_step())
2885          hooks = input_hooks + shutdown_hooks
2886          hooks.extend([
2887              TPUInfeedOutfeedSessionHook(
2888                  ctx,
2889                  enqueue_ops,
2890                  host_ops,
2891                  tpu_compile_op=compile_op,
2892                  run_infeed_loop_on_coordinator=(
2893                      run_infeed_loop_on_coordinator),
2894                  rendezvous=self._rendezvous[mode],
2895                  master=self._config.master,
2896                  session_config=self._session_config,
2897                  tpu_init_ops=tpu_init_ops),
2898              InstallSignalHandlerHook()
2899          ])
2900          if self._log_every_n_steps is not None:
2901            logging_hook_frequency = (  # Divide and round up
2902                (self._log_every_n_steps +
2903                 self._config.tpu_config.iterations_per_loop - 1) //
2904                self._config.tpu_config.iterations_per_loop)
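            # For example, log_every_n_steps=100 with iterations_per_loop=64
            # rounds up to a frequency of 2 loop iterations, i.e. one log line
            # roughly every 128 global steps.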
2905            hooks.append(
2906                training.LoggingTensorHook({
2907                    'loss': array_ops.identity(loss),
2908                    'step': global_step,
2909                },
2910                                           every_n_iter=logging_hook_frequency))
2911            examples_hook._set_steps_per_run(  # pylint: disable=protected-access
2912                self._config.tpu_config.iterations_per_loop)
2913            hooks.append(examples_hook)
2914
2915          if training_hooks:
2916            hooks.extend(training_hooks)
2917
2918          chief_hooks = []
2919          if (self._config.save_checkpoints_secs or
2920              self._config.save_checkpoints_steps):
2921            checkpoint_hook = training.CheckpointSaverHook(
2922                self.model_dir,
2923                save_secs=self._config.save_checkpoints_secs,
2924                save_steps=self._config.save_checkpoints_steps,
2925                scaffold=scaffold)
2926            checkpoint_hook._set_steps_per_run(  # pylint: disable=protected-access
2927                self._config.tpu_config.iterations_per_loop)
2928            chief_hooks.append(checkpoint_hook)
2929
2930          summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss)
2931          with ops.control_dependencies([loss]):
2932            update_ops = _sync_variables_ops(ctx)
2933            if ctx.embedding_config:
2934              update_ops.extend(embedding_variables_and_ops.retrieve_ops())
2935
2936          # Validate the TPU training graph to catch basic errors
2937          _validate_tpu_training_graph()
2938
2939          train_op = control_flow_ops.group(*update_ops)
2940          graph.add_to_collection(_TPU_TRAIN_OP, train_op)
2941
2942          return model_fn_lib.EstimatorSpec(
2943              mode,
2944              loss=loss,
2945              training_chief_hooks=chief_hooks,
2946              training_hooks=hooks,
2947              train_op=train_op,
2948              scaffold=scaffold)
2949
2950        if mode == model_fn_lib.ModeKeys.EVAL:
2951          compile_op, total_loss, host_calls, scaffold, eval_hooks = (
2952              _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))
2953          if ctx.embedding_config:
2954            g = ops.get_default_graph()
2955            table_to_config_dict = (
2956                ctx.embedding_config.tpu_embedding.table_to_config_dict)
2957            embedding_variable_name_by_table, _ = (
2958                _tpu_estimator_embedding.get_full_variable_names(
2959                    g, table_to_config_dict)
2960            )
2961            embedding_variables_and_ops = (
2962                ctx.embedding_config.tpu_embedding.create_variables_and_ops(
2963                    embedding_variable_name_by_table
2964                ))
2965            tpu_init_ops.extend(embedding_variables_and_ops.load_ops())
2966          iterations_per_loop_var = _create_or_get_iterations_per_loop()
2967          mean_loss = math_ops.div(
2968              total_loss,
2969              math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype))
2970
2971          with ops.control_dependencies([mean_loss]):
2972            # After the TPU evaluation computation is done (the mean_loss
2973            # tensor), read all variables back from the TPU and update the
2974            # eval step counter properly.
2975            internal_ops_to_run = _sync_variables_ops(ctx)
2976            internal_ops_to_run.append(
2977                _increase_eval_step_op(iterations_per_loop_var))
2978
2979          host_call_ret = host_calls.create_tpu_hostcall()
2980          eval_metric_ops = {}
2981          eval_update_ops = []
2982
2983          eval_metrics = host_call_ret.get('eval_metrics', {})
2984          if eval_metrics:
2985            # Creates a dummy metric update_op for all metrics. Estimator
2986            # expects all metrics in `eval_metric_ops` to have an update_op and
2987            # calls them one by one. The real metric update_ops are invoked in
2988            # a separate thread. So, here we give Estimator the dummy op for
2989            # all metrics.
2990            with ops.control_dependencies(internal_ops_to_run):
2991              dummy_update_op = control_flow_ops.no_op()
2992
2993            for k, v in eval_metrics.items():
2994              eval_metric_ops[k] = (v[0], dummy_update_op)
2995              eval_update_ops.append(v[1])
2996          else:
2997            # If no eval metrics are passed, create an identity node for the
2998            # loss and add `internal_ops_to_run` to its dependencies so that
2999            # `internal_ops_to_run` can be executed.
3000            with ops.control_dependencies(internal_ops_to_run):
3001              mean_loss = array_ops.identity(mean_loss)
3002
3003          if 'host_call' not in host_call_ret:
3004            host_ops = []
3005          else:
3006            host_ops = host_call_ret['host_call']
3007          hooks = [
3008              TPUInfeedOutfeedSessionHook(
3009                  ctx,
3010                  enqueue_ops,
3011                  eval_update_ops + host_ops,
3012                  tpu_compile_op=compile_op,
3013                  run_infeed_loop_on_coordinator=(
3014                      run_infeed_loop_on_coordinator),
3015                  rendezvous=self._rendezvous[mode],
3016                  master=self._config.evaluation_master,
3017                  session_config=self._session_config,
3018                  tpu_init_ops=tpu_init_ops)
3019          ] + input_hooks
3020
3021          if eval_hooks:
3022            hooks.extend(eval_hooks)
3023
3024          return model_fn_lib.EstimatorSpec(
3025              mode,
3026              loss=mean_loss,
3027              evaluation_hooks=hooks,
3028              eval_metric_ops=eval_metric_ops,
3029              scaffold=scaffold)
3030
3031        # Predict
3032        assert mode == model_fn_lib.ModeKeys.PREDICT
3033
3034        (compile_op, dummy_predict_op, host_calls,
3035         scaffold, prediction_hooks) = _predict_on_tpu_system(
3036             ctx, model_fn_wrapper, dequeue_fn)
3037        with ops.control_dependencies([dummy_predict_op]):
3038          internal_ops_to_run = _sync_variables_ops(ctx)
3039          with ops.control_dependencies(internal_ops_to_run):
3040            dummy_predict_op = control_flow_ops.no_op()
3041
3042        # In train and evaluation, the main TPU program is passed to the
3043        # monitored training session to run. Infeed enqueue and outfeed
3044        # dequeue are executed in side threads. This is not the configuration
3045        # for prediction mode.
3046        #
3047        # For prediction, the Estimator executes EstimatorSpec.predictions
3048        # directly and yields the element (via a generator) to the call site.
3049        # So, the outfeed-based prediction must go to MonitoredSession directly.
3050        # Other parts of the TPU execution are organized as follows.
3051        #
3052        # 1. All outfeed-based Tensors must be grouped with the predictions
3053        #    Tensors to form a single invocation, which avoids incorrectly
3054        #    triggering multiple outfeeds. To achieve this, `host_call` is
3055        #    placed in control_dependencies of `stopping_signals`, and
3056        #    `stopping_signals` is passed into _StoppingPredictHook, which sets
3057        #    the `stopping_signals` as SessionRunArgs. MonitoredSession merges
3058        #    all SessionRunArgs with the fetch in session.run together.
3059        #
3060        # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed Enqueue)
3061        #    are grouped together. They will be launched once and only once in
3062        #    side threads and they quit naturally according to the SAME stopping
3063        #    condition.
3064        enqueue_ops.append(dummy_predict_op)
3065
3066        host_call_ret = host_calls.create_tpu_hostcall()
3067        if 'host_call' not in host_call_ret:
3068          host_ops = []
3069        else:
3070          host_ops = host_call_ret['host_call']
3071
3072        predictions = host_call_ret['predictions']
3073        _verify_cross_hosts_transfer_size(
3074            predictions,
3075            message=(
3076                'The estimated size for TPUEstimatorSpec.predictions is too '
3077                'large.'))
3078        signals = host_call_ret['signals']
3079
3080        with ops.control_dependencies(host_ops):
3081          host_ops = []  # Empty; we do not need it anymore.
3082          scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal(
3083              signals)
3084          predictions = _PaddingSignals.slice_tensor_or_dict(
3085              predictions, signals)
3086
3087        hooks = [
3088            _StoppingPredictHook(scalar_stopping_signal),
3089            TPUInfeedOutfeedSessionHookForPrediction(
3090                ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode],
3091                tpu_compile_op=compile_op,
3092                master=self._config.master,
3093                session_config=self._session_config),
3094        ] + input_hooks
3095
3096        if prediction_hooks:
3097          hooks.extend(prediction_hooks)
3098
3099        return model_fn_lib.EstimatorSpec(
3100            mode,
3101            prediction_hooks=hooks,
3102            predictions=predictions,
3103            scaffold=scaffold)
3104
3105    return _model_fn
3106
3107
3108def _export_output_to_tensors(export_output):
3109  """Get a list of `Tensors` used in `export_output`.
3110
3111  Args:
3112    export_output: an `ExportOutput` object such as `ClassificationOutput`,
3113      `RegressionOutput`, or `PredictOutput`.
3114
3115  Returns:
3116    a list of tensors used in export_output.
3117
3118  Raises:
3119    ValueError: if `export_output` is not one of `ClassificationOutput`,
3120        `RegressionOutput`, or `PredictOutput`.
3121  """
3122  if isinstance(export_output, export_output_lib.ClassificationOutput):
3123    return [export_output.scores, export_output.classes]
3124  elif isinstance(export_output, export_output_lib.RegressionOutput):
3125    return [export_output.value]
3126  elif isinstance(export_output, export_output_lib.PredictOutput):
3127    return list(export_output.outputs.values())
3128  else:
3129    raise ValueError(
3130        '`export_output` must have type `ClassificationOutput`, '
3131        '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))
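  # Illustrative example (hedged; the tensor names are hypothetical): for
  #   export_output_lib.PredictOutput({'probabilities': probs, 'classes': cls})
  # this returns [probs, cls] (the dict values in iteration order), while a
  # RegressionOutput(value) yields the single-element list [value].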
3132
3133
3134def _clone_export_output_with_tensors(export_output, tensors):
3135  """Clones `export_output` but with new `tensors`.
3136
3137  Args:
3138    export_output: an `ExportOutput` object such as `ClassificationOutput`,
3139      `RegressionOutput`, or `PredictOutput`.
3140    tensors: a list of `Tensors` used to construct a new `export_output`.
3141
3142  Returns:
3143    A new `ExportOutput` like `export_output` but built from `tensors`.
3144
3145  Raises:
3146    ValueError: if `export_output` is not one of `ClassificationOutput`,
3147        `RegressionOutput`, or `PredictOutput`.
3148  """
3149  if isinstance(export_output, export_output_lib.ClassificationOutput):
3150    if len(tensors) != 2:
3151      raise ValueError('tensors must be of length 2; '
3152                       'got {}.'.format(len(tensors)))
3153    return export_output_lib.ClassificationOutput(*tensors)
3154  elif isinstance(export_output, export_output_lib.RegressionOutput):
3155    if len(tensors) != 1:
3156      raise ValueError('tensors must be of length 1; '
3157                       'got {}'.format(len(tensors)))
3158    return export_output_lib.RegressionOutput(*tensors)
3159  elif isinstance(export_output, export_output_lib.PredictOutput):
3160    return export_output_lib.PredictOutput(
3161        dict(zip(export_output.outputs.keys(), tensors)))
3162  else:
3163    raise ValueError(
3164        '`export_output` must have type `ClassificationOutput`, '
3165        '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))
3166
3167
3168def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
3169  """Executes `model_fn_wrapper` multiple times on all TPU shards."""
3170  iterations_per_loop_var = _create_or_get_iterations_per_loop()
3171
3172  (single_tpu_eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks
3173  ) = model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn)
3174
3175  @tpu_function.on_device_training_loop
3176  def multi_tpu_eval_steps_on_single_shard():
3177    return training_loop.repeat(iterations_per_loop_var, single_tpu_eval_step,
3178                                [_ZERO_LOSS])
3179
3180  (compile_op, loss,) = tpu.split_compile_and_shard(
3181      multi_tpu_eval_steps_on_single_shard,
3182      inputs=[],
3183      num_shards=ctx.num_replicas,
3184      outputs_from_all_shards=False,
3185      device_assignment=ctx.device_assignment)
3186
3187  loss = loss[0]
3188  scaffold = _get_scaffold(captured_scaffold_fn)
3189  return compile_op, loss, host_calls, scaffold, captured_eval_hooks.get()
3190
3191
3192def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
3193  """Executes `model_fn_wrapper` multiple times on all TPU shards."""
3194  iterations_per_loop_var = _create_or_get_iterations_per_loop()
3195
3196  (single_tpu_train_step, host_call, captured_scaffold_fn,
3197   captured_training_hooks) = (
3198       model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn))
3199
3200  @tpu_function.on_device_training_loop
3201  def multi_tpu_train_steps_on_single_shard():
3202    return training_loop.repeat(iterations_per_loop_var, single_tpu_train_step,
3203                                [_INITIAL_LOSS])
3204
3205  (compile_op, loss,) = tpu.split_compile_and_shard(
3206      multi_tpu_train_steps_on_single_shard,
3207      inputs=[],
3208      num_shards=ctx.num_replicas,
3209      outputs_from_all_shards=False,
3210      device_assignment=ctx.device_assignment)
3211
3212  loss = loss[0]
3213  scaffold = _get_scaffold(captured_scaffold_fn)
3214  return compile_op, loss, host_call, scaffold, captured_training_hooks.get()
3215
3216
3217def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
3218  """Executes `model_fn_wrapper` multiple times on all TPU shards."""
3219  (single_tpu_predict_step, host_calls, captured_scaffold_fn,
3220   captured_predict_hooks
3221  ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)
3222
3223  @tpu_function.on_device_training_loop
3224  def multi_tpu_predict_steps_on_single_shard():
3225
3226    def cond(scalar_stopping_signal):
3227      return math_ops.logical_not(
3228          _StopSignals.should_stop(scalar_stopping_signal))
3229
3230    inputs = [_StopSignals.NON_STOPPING_SIGNAL]
3231    outputs = training_loop.while_loop(
3232        cond, single_tpu_predict_step, inputs=inputs, name=b'loop')
3233    return outputs
3234
3235  (compile_op, dummy_predict_op,) = tpu.split_compile_and_shard(
3236      multi_tpu_predict_steps_on_single_shard,
3237      inputs=[],
3238      num_shards=ctx.num_replicas,
3239      outputs_from_all_shards=False,
3240      device_assignment=ctx.device_assignment)
3241
3242  dummy_predict_op = dummy_predict_op[0]
3243  scaffold = _get_scaffold(captured_scaffold_fn)
3244  return (compile_op, dummy_predict_op, host_calls, scaffold,
3245          captured_predict_hooks.get())
3246
3247
3248def _wrap_computation_in_while_loop(device, op_fn):
3249  """Wraps the ops generated by `op_fn` in tf.while_loop."""
3250
3251  def computation(i):
3252    with ops.control_dependencies(op_fn()):
3253      return i + 1
3254
3255  iterations_per_loop_var = _create_or_get_iterations_per_loop()
3256  # By setting parallel_iterations=1, the parallel execution in while_loop is
3257  # basically turned off.
3258  with ops.device(device):
3259    iterations = array_ops.identity(iterations_per_loop_var)
3260    return control_flow_ops.while_loop(
3261        lambda i: i < iterations,
3262        computation, [constant_op.constant(0)],
3263        parallel_iterations=1)
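  # Conceptually (an illustrative sketch only), the while_loop above behaves
  # like the following Python loop:
  #
  #   i = 0
  #   while i < iterations_per_loop:
  #     op_fn()  # e.g. runs the infeed enqueue ops
  #     i += 1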
3264
3265
3266def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
3267  """Wraps the ops generated by `op_fn` in tf.while_loop."""
3268
3269  def cond(scalar_stopping_signal):
3270    return math_ops.logical_not(
3271        _StopSignals.should_stop(scalar_stopping_signal))
3272
3273  def computation(unused_scalar_stopping_signal):
3274    return_value = op_fn()
3275    execute_ops = return_value['ops']
3276    signals = return_value['signals']
3277    with ops.control_dependencies(execute_ops):
3278      return _StopSignals.as_scalar_stopping_signal(signals)
3279
3280  # By setting parallel_iterations=1, the parallel execution in while_loop is
3281  # basically turned off.
3282  with ops.device(device):
3283    return control_flow_ops.while_loop(
3284        cond,
3285        computation, [_StopSignals.NON_STOPPING_SIGNAL],
3286        parallel_iterations=1)
3287
3288
3289def _validate_tpu_training_graph():
3290  """Validate graph before running distributed training.
3291
3292  Raises:
3293    ValueError: If the graph seems invalid for running on device
3294  """
3295  operations = ops.get_default_graph().get_operations()
3296
3297  # Check if there is at least one CrossReplicaSum operation in the graph.
3298  # This should be introduced by using the CrossShardOptimizer wrapper.
3299  cross_replica_sum_ops = [
3300      o for o in operations if o.type == _CROSS_REPLICA_SUM_OP
3301  ]
3302  if not cross_replica_sum_ops:
3303    raise ValueError(
3304        'CrossShardOptimizer must be used for model training on TPUs.')
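  # Hedged note: the CrossReplicaSum ops checked for above are typically
  # introduced by wrapping the user optimizer in model_fn, e.g.
  #
  #   optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
  #
  # so their absence usually means the wrapper was omitted.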
3305
3306
3307class _CapturedObject(object):
3308  """A placeholder to capture an object.
3309
3310  This is useful when we need to capture a Python object in the TensorFlow
3311  control flow body function and use it outside the control flow.
3312  """
3313
3314  def __init__(self):
3315    self._object = None
3316    self._captured = False
3317
3318  def capture(self, o):
3319    if self._captured:
3320      raise RuntimeError(
3321          'InternalError: Object can only be captured once. Please file a bug.')
3322
3323    self._captured = True
3324    self._object = o
3325
3326  def get(self):
3327    if not self._captured:
3328      raise RuntimeError(
3329          'InternalError: Object is not captured properly before `get`. '
3330          'Please file a bug.')
3331    return self._object
3332
3333
3334def _get_scaffold(captured_scaffold_fn):
3335  """Retrieves the Scaffold from `captured_scaffold_fn`."""
3336  with _CapturingContext(message='Inside scaffold_fn'):
3337    scaffold_fn = captured_scaffold_fn.get()
3338    if scaffold_fn:
3339      scaffold = scaffold_fn()
3340      if scaffold is None:
3341        raise ValueError(
3342            'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')
3343    else:
3344      scaffold = None
3345
3346  if scaffold:
3347    wrapped_finalize = scaffold.finalize
3348
3349    def _finalize():
3350      with _CapturingContext('Inside Scaffold.finalize'):
3351        wrapped_finalize()
3352
3353    scaffold.finalize = _finalize
3354  return scaffold
3355
3356
3357class _CapturingContext(control_flow_ops.ControlFlowContext):
3358  """Tracks references to Tensors defined in TPU replication."""
3359
3360  def __init__(self, message):
3361    control_flow_ops.ControlFlowContext.__init__(self)
3362    self._message = message
3363
3364  def to_control_flow_context_def(self, context_def, export_scope=None):
3365    # pylint: disable=useless-super-delegation
3366    # NOTE(slebedev): the method is required by `ControlFlowContext`.
3367    super(_CapturingContext, self).to_control_flow_context_def(
3368        context_def, export_scope)
3369
3370  def AddOp(self, op):  # pylint: disable=invalid-name
3371    for c in op.inputs:
3372      if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr:  # pylint: disable=protected-access
3373        raise ValueError('{}: Op {} depends on TPU computation {}, '
3374                         'which is not allowed.'.format(self._message, op, c))
3375
3376  def __enter__(self):
3377    # pylint: disable=protected-access
3378    self._g = ops.get_default_graph()
3379    self._old = self._g._get_control_flow_context()
3380    self._g._set_control_flow_context(self)
3381    # pylint: enable=protected-access
3382
3383  def __exit__(self, _, __, ___):  # pylint: disable=invalid-name
3384    self._g._set_control_flow_context(self._old)  # pylint: disable=protected-access
3385
3386
3387class _Inputs(object):
3388  """A data structure representing the input_fn returned values.
3389
3390  This also supports the returned value from input_fn as `Dataset`.
3391  """
3392
3393  def __init__(self, features=None, labels=None, dataset=None, signals=None):
3394    if dataset is not None and (features is not None or labels is not None or
3395                                signals is not None):
3396      raise RuntimeError('Internal Error: Either (features and labels) or '
3397                         'dataset should be provided, not both. Please file '
3398                         'a bug.')
3399
3400    self._features = features
3401    self._labels = labels
3402    self._signals = signals
3403
3404    self._dataset = dataset
3405    self._iterator = None
3406
3407  @staticmethod
3408  def from_input_fn(return_values):
3409    """Returns an `_Inputs` instance according to `input_fn` return value."""
3410    if isinstance(return_values, dataset_ops.DatasetV2):
3411      dataset = return_values
3412      return _Inputs(dataset=dataset)
3413
3414    features, labels = _Inputs._parse_inputs(return_values)
3415    return _Inputs(features, labels)
3416
3417  @staticmethod
3418  def _parse_inputs(return_values):
3419    if isinstance(return_values, tuple):
3420      features, labels = return_values
3421    else:
3422      features, labels = return_values, None
3423    return features, labels
3424
3425  @property
3426  def is_dataset(self):
3427    """Returns True if the return value from input_fn is Dataset."""
3428    return self._dataset is not None
3429
3430  def dataset_initializer(self):
3431    """Returns the dataset's initializer.
3432
3433    The initializer must be run before calling `features_and_labels`.
3434    """
3435    self._iterator = dataset_ops.make_initializable_iterator(self._dataset)
3436    return self._iterator.initializer
3437
3438  def features_and_labels(self):
3439    """Gets `features` and `labels`."""
3440    if self.is_dataset:
3441      if self._iterator is None:
3442        raise RuntimeError('Internal error: Must run dataset_initializer '
3443                           'before calling features_and_labels(). Please file '
3444                           'a bug!')
3445      return _Inputs._parse_inputs(self._iterator.get_next())
3446
3447    return (self._features, self._labels)
3448
3449  def signals(self):
3450    return self._signals
3451
3452  @property
3453  def dataset(self):
3454    return self._dataset
3455
3456
3457class _InputsWithStoppingSignals(_Inputs):
3458  """Inputs with `_StopSignals` inserted into the dataset."""
3459
3460  def __init__(self,
3461               dataset,
3462               batch_size,
3463               add_padding=False,
3464               num_invocations_per_step=1):
3465
3466    assert dataset is not None
3467    user_provided_dataset = dataset.map(
3468        _InputsWithStoppingSignals.insert_stopping_signal(
3469            stop=False, batch_size=batch_size, add_padding=add_padding))
3470    if num_invocations_per_step == 1:
3471      final_batch_dataset = dataset.take(1).map(
3472          _InputsWithStoppingSignals.insert_stopping_signal(
3473              stop=True, batch_size=batch_size, add_padding=add_padding))
3474    else:
3475      # We append (2 * num_invocations_per_step - 1) batches to exhaust the
3476      # user_provided_dataset and stop properly.
3477      # For example, if num_invocations_per_step is 2, we append 3 additional
3478      # padding batches: b1, b2, b3.
3479      # If user_provided_dataset contains two batches, a1 and a2, the loops are:
3480      # Step 1: [a1, a2]
3481      # Step 2: [b1, b2] -> STOP
3482      # If user_provided_dataset contains three batches, a1, a2 and a3, the
3483      # training loops are:
3484      # Step 1: [a1, a2]
3485      # Step 2: [a3, b1]
3486      # Step 3: [b2, b3] -> STOP.
3487      final_batch_dataset = dataset.take(1).map(
3488          _InputsWithStoppingSignals.insert_stopping_signal(
3489              stop=True, batch_size=batch_size, add_padding=add_padding))
3490      final_batch_dataset = final_batch_dataset.repeat(
3491          2 * num_invocations_per_step - 1)
3492
3493      def _set_mask(data_dict):
3494        signals = data_dict['signals']
3495        signals['padding_mask'] = array_ops.ones_like(signals['padding_mask'])
3496        data_dict['signals'] = signals
3497        return data_dict
3498
3499      # Mask out the extra batch.
3500      final_batch_dataset = final_batch_dataset.map(_set_mask)
3501
3502    dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2)
3503
3504    super(_InputsWithStoppingSignals, self).__init__(dataset=dataset)
3505    self._current_inputs = None
3506
3507  def features_and_labels(self):
3508    if self._current_inputs is not None:
3509      raise RuntimeError(
3510          'Internal Error: The previous inputs have not been properly '
3511          'consumed. First call features_and_labels, then call signals.')
3512
3513    inputs_with_signals = self._iterator.get_next()
3514    features = inputs_with_signals['features']
3515    labels = inputs_with_signals.get('labels')
3516
3517    self._current_inputs = inputs_with_signals
3518    return features, labels
3519
3520  def signals(self):
3521    """Returns the `Signals` from `_Inputs`."""
3522    if self._current_inputs is None:
3523      raise RuntimeError(
3524          'Internal Error: The current inputs have not been properly '
3525          'generated. First call features_and_labels, then call signals.')
3526    signals = self._current_inputs['signals']
3527    self._current_inputs = None
3528    return signals
3529
3530  @staticmethod
3531  def insert_stopping_signal(stop, batch_size, add_padding=False):
3532    """Inserts stopping_signal into dataset via _map_fn.
3533
3534    Here we change the data structure in the dataset, such that the return
3535    value is now a dictionary with `features`, `labels`, and `signals` as
3536    three distinct keys. This provides a better structure, which makes it
3537    easier to decompose the inputs (see `features_and_labels`).
3538
3539    Args:
3540      stop: bool, state of current stopping signals.
3541      batch_size: int, batch size.
3542      add_padding: bool, whether to pad the tensor to full batch size.
3543
3544    Returns:
3545      A map_fn passed to dataset.map API.
3546    """
3547
3548    def _map_fn(*args):
3549      """The map fn to insert signals."""
3550      if len(args) == 1:
3551        # Unpack the single Tensor/dict argument as features. This is required
3552        # when the input_fn returns no labels.
3553        args = args[0]
3554      features, labels = _Inputs._parse_inputs(args)
3555      new_input_dict = {}
3556
3557      if add_padding:
3558        padding_mask, features, labels = (
3559            _PaddingSignals.pad_features_and_labels(features, labels,
3560                                                    batch_size))
3561
3562        new_input_dict['features'] = features
3563        if labels is not None:
3564          new_input_dict['labels'] = labels
3565
3566      else:
3567        new_input_dict['features'] = features
3568        if labels is not None:
3569          new_input_dict['labels'] = labels
3570        padding_mask = None
3571
3572      new_input_dict['signals'] = _StopSignals(
3573          stop=stop, batch_size=batch_size,
3574          padding_mask=padding_mask).as_dict()
3575
3576      return new_input_dict
3577
3578    return _map_fn
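    # Illustrative sketch of the resulting element structure (hedged; the
    # exact shapes come from `_StopSignals.as_dict` and `_PaddingSignals`):
    #
    #   {'features': ...,
    #    'labels': ...,   # present only when the input_fn returns labels
    #    'signals': {'stopping': <bool Tensor of shape [batch_size, 1]>,
    #                'padding_mask': <int32 Tensor of shape [batch_size]>}}
    #
    # where 'padding_mask' only appears when add_padding=True.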
3579
3580
3581class _StopSignals(object):
3582  """Signals class holding all logic to handle TPU stopping condition."""
3583
3584  NON_STOPPING_SIGNAL = False
3585  STOPPING_SIGNAL = True
3586
3587  def __init__(self, stop, batch_size, padding_mask=None):
3588    self._stop = stop
3589    self._batch_size = batch_size
3590    self._padding_mask = padding_mask
3591
3592  def as_dict(self):
3593    """Returns the signals as Python dict."""
3594    shape = [self._batch_size, 1]
3595    dtype = dtypes.bool
3596
3597    if self._stop:
3598      stopping = array_ops.ones(shape=shape, dtype=dtype)
3599    else:
3600      stopping = array_ops.zeros(shape=shape, dtype=dtype)
3601
3602    signals = {'stopping': stopping}
3603    if self._padding_mask is not None:
3604      signals['padding_mask'] = self._padding_mask
3605    return signals
3606
3607  @staticmethod
3608  def as_scalar_stopping_signal(signals):
3609    return array_ops.identity(signals['stopping'][0][0])
3610
3611  @staticmethod
3612  def should_stop(scalar_stopping_signal):
3613    """Detects whether scalar_stopping_signal indicates stopping."""
3614    if isinstance(scalar_stopping_signal, ops.Tensor):
3615      # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF
3616      # way to express the bool check whether scalar_stopping_signal is True.
3617      return math_ops.logical_and(scalar_stopping_signal,
3618                                  _StopSignals.STOPPING_SIGNAL)
3619    else:
3620      # For the non-Tensor case, this is used in a SessionRunHook, so we
3621      # cannot modify the graph anymore. Here, we use pure Python.
3622      return bool(scalar_stopping_signal)
3623
3624
3625class _PaddingSignals(object):
3626  """Signals class holding all logic to handle padding."""
3627
3628  @staticmethod
3629  def pad_features_and_labels(features, labels, batch_size):
3630    """Pads out the batch dimension of features and labels."""
3631    real_batch_size = array_ops.shape(
3632        _PaddingSignals._find_any_tensor(features))[0]
3633
3634    batch_size_tensor = constant_op.constant(batch_size, dtypes.int32)
3635
3636    check_greater = check_ops.assert_greater_equal(
3637        batch_size_tensor,
3638        real_batch_size,
3639        data=(batch_size_tensor, real_batch_size),
3640        message='The real batch size should not be greater than batch_size.')
3641
3642    with ops.control_dependencies([check_greater]):
3643      missing_count = batch_size_tensor - real_batch_size
3644
3645    def pad_single_tensor(tensor):
3646      """Pads out the batch dimension of a tensor to the complete batch_size."""
3647      rank = len(tensor.shape)
3648      assert rank > 0
3649      padding = array_ops.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))
3650      padded_shape = (batch_size,) + tuple(tensor.shape[1:])
3651      padded_tensor = array_ops.pad(tensor, padding)
3652      padded_tensor.set_shape(padded_shape)
3653      return padded_tensor
3654
3655    def nest_pad(tensor_or_dict):
3656      return nest.map_structure(pad_single_tensor, tensor_or_dict)
3657
3658    features = nest_pad(features)
3659    if labels is not None:
3660      labels = nest_pad(labels)
3661
3662    padding_mask = _PaddingSignals._padding_mask(real_batch_size, missing_count,
3663                                                 batch_size)
3664
3665    return padding_mask, features, labels
3666
3667  @staticmethod
3668  def slice_tensor_or_dict(tensor_or_dict, signals):
3669    """Slice the real Tensors according to padding mask in signals."""
3670
3671    padding_mask = signals['padding_mask']
3672    batch_size = array_ops.shape(padding_mask)[0]
3673
3674    def verify_batch_size(tensor):
3675      check_batch_size = math_ops.equal(batch_size, tensor.shape[0])
3676      with ops.control_dependencies([check_batch_size]):
3677        return array_ops.identity(tensor)
3678
3679    def slice_single_tensor(tensor):
3680      rank = len(tensor.shape)
3681      assert rank > 0
3682      real_batch_size = batch_size - math_ops.reduce_sum(padding_mask)
3683      return verify_batch_size(tensor)[0:real_batch_size]
3684
3685    # As we split the Tensors across all TPU cores and concat them back, it is
3686    # important to ensure the real data is placed before the padded data, i.e.,
3687    # order is preserved. Given that, the sliced padding mask should be all 0's.
3688    # If this assertion fails, the slice logic here would not hold.
3689    sliced_padding_mask = slice_single_tensor(padding_mask)
3690    assert_padding_mask = math_ops.equal(
3691        math_ops.reduce_sum(sliced_padding_mask), 0)
3692
3693    with ops.control_dependencies([assert_padding_mask]):
3694      should_stop = _StopSignals.should_stop(
3695          _StopSignals.as_scalar_stopping_signal(signals))
3696
3697    is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0)
3698
3699    def slice_fn(tensor):
3700      # If the current batch is a full batch or part of the stopping signals,
3701      # we do not need to slice, which saves time.
3702      return control_flow_ops.cond(
3703          math_ops.logical_or(should_stop, is_full_batch),
3704          (lambda: verify_batch_size(tensor)),
3705          (lambda: slice_single_tensor(tensor)))
3706
3707    return nest.map_structure(slice_fn, tensor_or_dict)
3708
3709  @staticmethod
3710  def _find_any_tensor(batch_features):
3711    tensors = [
3712        x for x in nest.flatten(batch_features) if isinstance(x, ops.Tensor)
3713    ]
3714    if not tensors:
3715      raise ValueError('Cannot find any Tensor in features dict.')
3716    return tensors[0]
3717
3718  @staticmethod
3719  def _padding_mask(real_batch_size, missing_count, batch_size):
3720    padding_mask = array_ops.concat([
3721        array_ops.zeros((real_batch_size,), dtype=dtypes.int32),
3722        array_ops.ones((missing_count,), dtype=dtypes.int32)
3723    ],
3724                                    axis=0)
3725    padding_mask.set_shape((batch_size,))
3726    return padding_mask
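    # Illustrative example: with real_batch_size=3, missing_count=5 and
    # batch_size=8, the mask above evaluates to [0, 0, 0, 1, 1, 1, 1, 1],
    # where 0 marks a real example and 1 marks a padded one.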
3727
3728
3729def _verify_cross_hosts_transfer_size(tensor_dict, message):
3730  total_size = 0
3731  tensor_structure = {}
3732  for key, tensor in tensor_dict.items():
3733    shape = tensor.shape
3734    size = np.product(shape) * tensor.dtype.size
3735    tensor_structure[key] = shape
3736    total_size += size
3737  if total_size >= _ONE_GIGABYTE:
3738    raise ValueError(
3739        '{} The transfer size is larger than the protobuf limit. Please '
3740        'consider using Tensors with smaller shapes or reducing the batch '
3741        'size. Given:\n'
3742        '{}'.format(
3743            message, '\n'.join([
3744                ' -- Key: {}, Shape: {}'.format(k, v)
3745                for k, v in tensor_structure.items()
3746            ])))
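  # Rough sizing example (assuming _ONE_GIGABYTE is 2**30 bytes): a single
  # float32 Tensor of shape [1024, 1024, 256] takes 1024 * 1024 * 256 * 4
  # bytes = 2**30 bytes, which already reaches the limit and raises above.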
3747
3748
3749def _add_item_to_params(params, key, value):
3750  """Adds a new item into `params`."""
3751  if hasattr(params, 'set_hparam'):
3752    # For HParams, we need to use its special API.
3753    if key in params:
3754      params.set_hparam(key, value)
3755    else:
3756      params.add_hparam(key, value)
3757  else:
3758    # Otherwise, params is a Python dict.
3759    params[key] = value
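  # Illustrative example (hedged): for a plain dict,
  #   _add_item_to_params({'foo': 1}, 'batch_size', 128)
  # simply sets params['batch_size'] = 128, while for an HParams object the
  # same call goes through set_hparam/add_hparam instead.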
3760
3761
3762def export_estimator_savedmodel(estimator,
3763                                export_dir_base,
3764                                serving_input_receiver_fn,
3765                                assets_extra=None,
3766                                as_text=False,
3767                                checkpoint_path=None,
3768                                strip_default_attrs=False):
3769  """Export `Estimator` trained model for TPU inference.
3770
3771  Args:
3772    estimator: `Estimator` with which model has been trained.
3773    export_dir_base: A string containing a directory in which to create
3774      timestamped subdirectories containing exported SavedModels.
3775    serving_input_receiver_fn: A function that takes no argument and returns a
3776      `ServingInputReceiver` or `TensorServingInputReceiver`.
3777    assets_extra: A dict specifying how to populate the assets.extra directory
3778      within the exported SavedModel, or `None` if no extra assets are needed.
3779    as_text: whether to write the SavedModel proto in text format.
3780    checkpoint_path: The checkpoint path to export.  If `None` (the default),
3781      the most recent checkpoint found within the model directory is chosen.
3782    strip_default_attrs: Boolean. If `True`, default-valued attributes will be
3783      removed from the NodeDefs.
3784
3785  Returns:
3786    The string path to the exported directory.
3787  """
3788  # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use
3789  # `estimator.config`.
3790  config = tpu_config.RunConfig(model_dir=estimator.model_dir)
3791  est = TPUEstimator(
3792      estimator._model_fn,  # pylint: disable=protected-access
3793      config=config,
3794      params=estimator.params,
3795      use_tpu=True,
3796      train_batch_size=2048,  # Does not matter.
3797      eval_batch_size=2048,  # Does not matter.
3798  )
3799  return est.export_savedmodel(export_dir_base, serving_input_receiver_fn,
3800                               assets_extra, as_text, checkpoint_path,
3801                               strip_default_attrs)
3802