• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""*Experimental* support for running Keras models on the TPU.
16
17To use, wrap your model with the `keras_support.tpu_model` function.
18
19Example usage:
20
21```
22image = tf.keras.layers.Input(shape=(28, 28, 3), name='image')
c1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3))(image)
24flattened = tf.keras.layers.Flatten()(c1)
25logits = tf.keras.layers.Dense(10, activation='softmax')(flattened)
26model = tf.keras.Model(inputs=[image], outputs=[logits])
27
28resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
29strategy = keras_support.TPUDistributionStrategy(resolver)
30model = keras_support.tpu_model(model, strategy=strategy)
31
32# Only TF optimizers are currently supported.
33model.compile(optimizer=tf.train.AdamOptimizer(), ...)
34
35# `images` and `labels` should be Numpy arrays.  Support for tensor input
36# (e.g. datasets) is planned.
37model.fit(images, labels)
38```
39"""
40
41# pylint: disable=protected-access
42
43from __future__ import absolute_import
44from __future__ import division
45from __future__ import print_function
46
47import abc
48import collections
49import contextlib
50import re
51import sys
52import time
53
54import numpy as np
55import six
56
57from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib
58from tensorflow.contrib.tpu.python.ops import tpu_ops
59from tensorflow.contrib.tpu.python.tpu import keras_tpu_variables
60from tensorflow.contrib.tpu.python.tpu import tpu
61from tensorflow.contrib.tpu.python.tpu import tpu_function
62from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
63from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
64from tensorflow.core.protobuf import config_pb2
65from tensorflow.core.protobuf.tpu import compilation_result_pb2 as tpu_compilation_result
66from tensorflow.python import tf2
67from tensorflow.python.client import session as tf_session
68from tensorflow.python.data.ops import dataset_ops
69from tensorflow.python.data.ops import iterator_ops
70from tensorflow.python.eager import context
71from tensorflow.python.estimator import model_fn as model_fn_lib
72from tensorflow.python.framework import constant_op
73from tensorflow.python.framework import dtypes
74from tensorflow.python.framework import errors
75from tensorflow.python.framework import ops
76from tensorflow.python.framework import tensor_shape
77from tensorflow.python.framework import tensor_spec
78from tensorflow.python.keras import backend as K
79from tensorflow.python.keras import callbacks as cbks
80from tensorflow.python.keras import metrics as metrics_module
81from tensorflow.python.keras import models
82from tensorflow.python.keras import optimizers as keras_optimizers
83from tensorflow.python.keras.engine import base_layer
84from tensorflow.python.keras.engine import base_layer_utils
85from tensorflow.python.keras.engine import training_arrays
86from tensorflow.python.keras.engine import training_utils
87from tensorflow.python.keras.layers import embeddings
88from tensorflow.python.keras.utils.generic_utils import make_batches
89from tensorflow.python.keras.utils.generic_utils import slice_arrays
90from tensorflow.python.ops import array_ops
91from tensorflow.python.ops import gen_linalg_ops
92from tensorflow.python.ops import math_ops
93from tensorflow.python.ops import random_ops
94from tensorflow.python.ops import variable_scope
95from tensorflow.python.ops import variables
96from tensorflow.python.platform import tf_logging as logging
97from tensorflow.python.util.deprecation import deprecated
98
99
# TODO(b/114775106): temporary shim to optionally initialize the TPU
# This increases the odds our session is initialized, but shouldn't be needed.
# Module-level cache of the probe op built by `_maybe_initialize_tpu`; that
# function rebuilds it whenever the cached op belongs to a graph other than the
# current default graph.
_TEST_REWRITE_OP = None
103
104
def _maybe_initialize_tpu(session):
  """Probes the TPU with a trivial computation and initializes it on failure.

  A tiny rewritten computation is run in `session`.  If that run fails with
  `FailedPreconditionError`, the TPU system is initialized through the same
  session.

  Args:
    session: The session used to run the probe op and, if required, the TPU
      initialization op.
  """
  global _TEST_REWRITE_OP
  try:
    # Reuse the cached probe op when possible to avoid another round of graph
    # optimization; it is only valid for the graph it was built in.
    probe_op = _TEST_REWRITE_OP
    stale = (
        probe_op is None or probe_op[0].graph != ops.get_default_graph())
    if stale:

      def test_op():
        one = constant_op.constant(1)
        return one + constant_op.constant(1)

      probe_op = tpu.rewrite(test_op)
      _TEST_REWRITE_OP = probe_op

    session.run(probe_op)
  except errors.FailedPreconditionError:
    session.run(tpu.initialize_system())
123
124
@contextlib.contextmanager
def _tpu_session_context():
  """Context manager that probes the TPU and resets Keras on session errors.

  On entry, the TPU behind the current Keras session is initialized if needed.
  A `FailedPreconditionError` or `AbortedError` raised during that step, or
  inside the managed block, clears the Keras session and is re-raised as a
  plain `Exception` whose message ends with the text of the original error.

  Yields:
    Nothing; control is simply handed to the managed block.

  Raises:
    Exception: If connecting to or initializing the TPU failed.
  """
  try:
    _maybe_initialize_tpu(K.get_session())
    yield
  except (errors.FailedPreconditionError, errors.AbortedError) as e:
    # The session is unusable at this point; drop it so a new one is created.
    K.clear_session()
    message = """
An error occurred connecting or initializing your TPU.

The session has been reset. re-run keras_to_tpu_model to create a new session.
""" + str(e)
    raise Exception(message)
138
139
def setup_tpu_session(cluster_resolver):
  """Makes the Keras session a TPU-initialized session for the given cluster.

  If the current Keras session already targets the resolver's master and has
  been marked as TPU-initialized by this function, nothing is done.  Otherwise
  a new session is created, the TPU system is initialized through it, and it
  is installed with `K.set_session`.

  Args:
    cluster_resolver: Object providing `master()` and `cluster_spec()` for the
      cluster to connect to.
  """
  target = cluster_resolver.master()

  # Use the existing session if we're already connected to this TPU.
  # N.B. K.get_session() is a non-trivial operation, and may fail if the remote
  # session has been reset.
  try:
    current = K.get_session()
    reusable = (
        current._target == target and
        getattr(current, '_tpu_initialized', None))
    if reusable:
      return
  except errors.AbortedError:
    # We lost the remote session and need to re-initialize.
    logging.warning('Lost remote session: creating a new session.')

  spec = cluster_resolver.cluster_spec()
  session_config = config_pb2.ConfigProto(isolate_session_state=True)
  if spec:
    session_config.cluster_def.CopyFrom(spec.as_cluster_def())

  new_session = tf_session.Session(target=target, config=session_config)
  new_session.run(tpu.initialize_system())
  # Marker consulted by the reuse check above on later calls.
  new_session._tpu_initialized = True

  # N.B. We have to call `K.set_session()` AND set our session as the
  # TF default. `K.get_session()` surprisingly does not return the value
  # supplied by K.set_session otherwise.
  K.set_session(new_session)
169
170
171try:
172  from scipy.sparse import issparse  # pylint: disable=g-import-not-at-top
173except ImportError:
174  issparse = None
175
176
def get_tpu_system_metadata(tpu_cluster_resolver):
  """Queries TPU system metadata for the cluster behind a resolver.

  Args:
    tpu_cluster_resolver: Resolver providing `master()` and `cluster_spec()`.

  Returns:
    The metadata object returned by the TPU system metadata query (topology is
    not queried).
  """
  master_address = tpu_cluster_resolver.master()

  spec = tpu_cluster_resolver.cluster_spec()
  if spec:
    cluster_def = spec.as_cluster_def()
  else:
    cluster_def = None

  # pylint: disable=protected-access
  return tpu_system_metadata_lib._query_tpu_system_metadata(
      master_address, cluster_def=cluster_def, query_topology=False)
189
190
class TPUDistributionStrategy(object):
  """Strategy describing how a Keras model is run on a TPU worker."""

  def __init__(self, tpu_cluster_resolver=None, using_single_core=False):
    """Construct a TPUDistributionStrategy.

    Args:
      tpu_cluster_resolver: Any instance of `TPUClusterResolver`. If None, one
        is created with '' as the master address.
      using_single_core: Bool. Debugging option that may be removed once model
        replication is mature. When `False` (the default) the number of cores
        reported by the TPU system metadata is used; when `True` the model is
        forced onto a single core, i.e. no replication.

    Raises:
      RuntimeError: If TF 2 behavior is enabled.
      Exception: No TPU found on the given worker.
    """
    deprecation_message = (
        'Keras support is now deprecated in support of TPU Strategy. '
        'Please follow the distribution strategy guide on tensorflow.org '
        'to migrate to the 2.0 supported version.')
    if tf2.enabled():
      raise RuntimeError(deprecation_message)
    logging.warning(deprecation_message)

    if tpu_cluster_resolver is None:
      tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')

    metadata = get_tpu_system_metadata(tpu_cluster_resolver)
    self._tpu_metadata = metadata
    self._tpu_cluster_resolver = tpu_cluster_resolver
    if using_single_core:
      self._num_cores = 1
    else:
      self._num_cores = metadata.num_cores

    # Walk device list to identify TPU worker for enqueue/dequeue operations.
    job_pattern = re.compile('/job:([^/]+)')
    for device in metadata.devices:
      if 'TPU:0' not in device.name:
        continue
      self._worker_name = job_pattern.search(device.name).group(1)
      break
    else:
      raise Exception('No TPU found on given worker.')

  def _make_assignment_for_model(self, cpu_model):
    """Makes a `TPUAssignment` for the passed in `cpu_model`.

    Stateful models are not replicated: they are assigned a single core even
    when more cores are available.
    """
    cores = self._num_cores
    if cores > 1 and cpu_model.stateful:
      logging.warning(
          'Model replication does not currently support stateful models.  '
          'Degrading to a single core.')
      cores = 1

    return TPUAssignment(worker_name=self._worker_name, num_cores=cores)
246
247
class TPUAssignment(object):
  """Holds the TPU resources assigned to one concrete model.

  Instances are created by `TPUDistributionStrategy`, which can therefore
  adjust the number of cores to use based on the model and input batch sizes.
  """

  def __init__(self, worker_name, num_cores):
    self._worker_name, self._num_cores = worker_name, num_cores

  worker_name = property(
      lambda self: self._worker_name,
      doc='Name of the worker job this assignment refers to.')

  # TODO(xiejw): Support automatically assign num_cores based on inputs.
  num_towers = property(
      lambda self: self._num_cores,
      doc='Number of model replicas; currently the assigned core count.')
268
269
class TPUEmbedding(embeddings.Embedding):
  """Embedding layer usable on TPU.

  The default Keras layer is not TPU compatible.  This layer is a drop-in
  replacement: it has the same behavior and will work on CPU and GPU devices.
  The lookup is expressed as a one-hot encoding contracted with the embedding
  matrix.
  """

  def build(self, input_shape):
    """Builds the layer, rejecting shapes whose first dimension is unknown."""
    leading_dim = input_shape[0]
    if leading_dim is None:
      raise ValueError(
          'TPUEmbeddings must have a fixed input_length or input shape.')
    return super(TPUEmbedding, self).build(input_shape)

  def call(self, inputs):
    """Looks up embeddings for `inputs`, casting the indices to int32 first."""
    indices = inputs
    if K.dtype(indices) != 'int32':
      indices = math_ops.cast(indices, 'int32')

    one_hot_indices = array_ops.one_hot(indices, self.input_dim)
    return math_ops.tensordot(one_hot_indices, self.embeddings, 1)
289
290
def _cross_replica_concat(tensor, core_id, num_cores, name):
  """Concatenate `tensor` across cores.

  Each core writes its tensor into its own slot of an otherwise zero buffer
  with `num_cores` slots; a cross-replica sum of those buffers then yields the
  concatenation on every core.  The sum is carried out in float32 and the
  result is cast back to the dtype of `tensor`.

  Args:
    tensor: The tensor to be concatenated. Its dtype must be one of bfloat16,
      float32 or int32.
    core_id: Tensor indicating the current TPU core.
    num_cores: Python int. The total number of TPU cores in the system.
    name: The string name to print for debugging.

  Returns:
    The same concatenated Tensor on each core.

  Raises:
    TypeError: If the dtype of `tensor` is not supported.
  """

  input_dtype = tensor.dtype
  if input_dtype not in [dtypes.bfloat16, dtypes.float32, dtypes.int32]:
    raise TypeError('For model replication, only (bfloat16, float32 and int32) '
                    'is supported for model outputs and targets. Got {} for '
                    '{}.'.format(input_dtype, name))

  batch_size = tensor.shape[0]
  # One-hot selector over cores: 1.0 in this core's slot, 0.0 elsewhere.
  mask = math_ops.cast(
      math_ops.equal(np.arange(num_cores, dtype=np.int32), core_id),
      dtypes.float32)
  # Add trailing singleton dims so the mask broadcasts against `tensor`.
  mask = array_ops.reshape(mask, [num_cores] + [1] * tensor.shape.ndims)
  result = mask * math_ops.cast(tensor, dtypes.float32)
  # Merge the leading (core, batch) dims into a single batch-like dim.
  local_tensor_with_holes = array_ops.reshape(result,
                                              [-1] + result.shape.as_list()[2:])
  concat_tensor = tpu_ops.cross_replica_sum(local_tensor_with_holes)
  concat_tensor.set_shape((num_cores * batch_size,) + tuple(tensor.shape[1:]))

  # Compare dtypes (not the tensor object itself) to decide whether the float32
  # result has to be converted back to the caller's dtype.
  if concat_tensor.dtype != input_dtype:
    concat_tensor = math_ops.cast(concat_tensor, input_dtype)
  return concat_tensor
324
325
class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
  """An optimizer that averages gradients across TPU shards.

  Wraps an existing optimizer: gradients are summed across replicas and divided
  by the number of shards before the wrapped optimizer builds its updates.
  Attribute lookups not satisfied by this wrapper are forwarded to the wrapped
  optimizer.
  """

  def __init__(self, opt, name='KerasCrossShardOptimizer'):
    """Construct a new cross-shard optimizer.

    Args:
      opt: An existing `Optimizer` to encapsulate.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "KerasCrossShardOptimizer".
    """
    super(KerasCrossShardOptimizer, self).__init__()
    self._name = name
    self._opt = opt
    logging.info('KerasCrossShard: %s %s', self._opt, self._opt.weights)

  def get_updates(self, loss, params):
    """Returns the wrapped optimizer's updates using averaged gradients."""
    # Route the wrapped optimizer's gradient computation through this wrapper
    # so that its updates are built from the cross-replica averaged gradients.
    self._opt.get_gradients = self.get_gradients
    return self._opt.get_updates(loss, params)

  def get_gradients(self, loss, params):
    """Returns gradients of `loss` w.r.t. `params`, averaged across shards."""
    num_shards = tpu_function.get_tpu_context().number_of_shards
    grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params)
    return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]

  def get_weights(self):
    """Returns the weights of the wrapped optimizer."""
    return self._opt.get_weights()

  def get_config(self):
    """Returns the config of the wrapped optimizer."""
    return self._opt.get_config()

  # Defer remaining operations to the underlying optimizer
  def __getattr__(self, key):
    # `__getattr__` only runs when normal lookup fails.  If `_opt` itself is
    # missing (e.g. on an instance created without running `__init__`, as
    # copy/pickle do), forwarding would look up `self._opt` again and recurse
    # without bound; report the attribute as missing instead.
    if key == '_opt':
      raise AttributeError(key)
    return getattr(self._opt, key)
363
364
class TPUModelOp(
    collections.namedtuple(
        'TPUModelOp',
        ('compile_op', 'execute_op', 'infeed_tensors', 'infeed_op',
         'outfeed_op'))):
  """Bundle of the ops and infeed tensors that make up one compiled model."""
370
371
372def _valid_name(tensor_name):
373  """Return a valid tensor name (strips '/', ':', etc)."""
374  return re.sub('[^a-zA-Z0-9_-]+', '', tensor_name)
375
376
def _replicated_optimizer(opt):
  """Wrap the optimizer `opt` in a cross-shard optimizer.

  A Keras `TFOptimizer` has its underlying TF optimizer wrapped in
  `CrossShardOptimizer`; any other optimizer is wrapped in
  `KerasCrossShardOptimizer`.
  """
  # Always wrap `opt`, even if we are running on a single core.  This ensures
  # Keras properly tracks and initializes optimizer variables.
  if not isinstance(opt, keras_optimizers.TFOptimizer):
    return KerasCrossShardOptimizer(opt)
  return tpu_optimizer.CrossShardOptimizer(opt.optimizer)
386
387
def _clone_optimizer(optimizer, config=None, worker_name=None):
  """Returns a clone of `optimizer` built from `config` or its own config.

  Args:
    optimizer: The optimizer to clone. Objects that are not Keras optimizers
      are returned unchanged; a `TFOptimizer` is re-wrapped around the same
      underlying TF optimizer.
    config: Optional config to build the clone from. Defaults to
      `optimizer.get_config()`.
    worker_name: Optional job name; when given, the clone is created under that
      job's CPU:0 device.

  Returns:
    The cloned (or, for non-Keras optimizers, the original) optimizer.
  """
  if not isinstance(optimizer, keras_optimizers.Optimizer):
    # In the first call to tpu_model(model), Keras may not have wrapped the TF
    # optimizer in the TFOptimizer helper, e.g., the given model isn't compiled
    # or optimizer isn't set, and later generated tpu_model compiles with a TF
    # optimizer.
    return optimizer

  if isinstance(optimizer, keras_optimizers.TFOptimizer):
    return keras_optimizers.TFOptimizer(optimizer.optimizer)

  clone_config = optimizer.get_config() if config is None else config
  logging.info('Cloning %s %s', optimizer.__class__.__name__, clone_config)

  job_prefix = '/job:%s' % worker_name if worker_name else ''
  # Explicitly put optimizer parameter variables on TPU worker.
  with ops.device('%s/device:CPU:0' % job_prefix):
    return optimizer.__class__.from_config(clone_config)
407
408
class TPURewriteContext(object):
  """Prepare the environment for a Keras model during `tpu.rewrite`.

  This overrides the default placeholder behaviour to instead refer to a preset
  input mapping.  Placeholders are unsupported in TPU compiled code, and must
  be replaced with explicit inputs or values from the infeed queue.

  Instead of explicitly threading inputs all the way through the Keras codebase,
  we override the behavior of the placeholder while compiling and inject the
  Tensors from the infeed in place of the placeholder.

  Similarly, as we compile a new sub-graph for each unique shape and execution
  mode, we need to override the behavior of an embedded `name_scope` call in
  the base Keras layer code.  This allows us to re-use the same weights across
  many compiles and share a single session/graph.

  While the context is active it also replaces
  `base_layer_utils.make_variable`, `random_ops.random_normal` and
  `gen_linalg_ops.qr` with stand-ins.  All five patched module attributes are
  restored on exit.
  """

  def __init__(self, input_map):
    """Creates the context.

    Args:
      input_map: Mapping from placeholder name to the value that is returned in
        place of a new placeholder with that name.
    """
    self._input_map = input_map
    # Originals of the patched functions; populated in `__enter__`.
    self._default_placeholder = None
    self._default_name_scope = None

  def __enter__(self):
    """Saves the original functions and installs the overrides."""

    def _placeholder(dtype, shape=None, name=None):  # pylint: disable=unused-argument
      # Return the preset input for known names; otherwise fall back to the
      # original placeholder implementation.
      logging.info('Remapping placeholder for %s', name)
      if name in self._input_map:
        return self._input_map[name]
      else:
        logging.info('Default: %s', name)
        return self._default_placeholder(dtype, shape, name)

    def _name_scope(name, default_name=None, values=None):
      # Inspect the immediate caller's frame: when the caller is a method of a
      # Keras `Layer` (its local `self`) and a name was given, open a reusing
      # variable scope instead of a plain name scope.
      caller_frame = sys._getframe().f_back
      caller_obj = caller_frame.f_locals.get('self')
      if (caller_obj is not None and
          isinstance(caller_obj, base_layer.Layer) and name is not None):
        return variable_scope.variable_scope(
            name, default_name, values, reuse=variable_scope.AUTO_REUSE)

      return self._default_name_scope(name, default_name, values)

    # Remember the originals so `__exit__` can restore them.
    self._default_placeholder = array_ops.placeholder
    self._default_name_scope = ops.name_scope
    self._default_make_variable = base_layer_utils.make_variable
    self._default_random_normal = random_ops.random_normal
    self._default_qr = gen_linalg_ops.qr

    array_ops.placeholder = _placeholder

    # Replace random_ops.random_normal with a dummy function because
    # `random_normal` isn't yet implemented on the TPU. Because these
    # initialized values are overwritten by the CPU values, this is okay.
    def random_normal(shape,
                      mean=0.0,
                      stddev=1.0,
                      dtype=dtypes.float32,
                      seed=None,
                      name=None):
      del mean
      del stddev
      del seed
      return array_ops.zeros(shape, dtype=dtype, name=name)

    random_ops.random_normal = random_normal

    # Replace gen_linalg_ops.qr because QR decomposition is not yet implemented.
    # TODO(saeta): Remove qr override once we confirm the qr implementation is
    # ok.
    # pylint: disable=redefined-builtin
    def qr(input, full_matrices=False, name=None):
      """Dummy implementation of qr decomposition."""
      del full_matrices  # TODO(saeta): Properly handle the full matrix case.
      input_shape = input.shape
      if len(input_shape) < 2:
        raise ValueError('Invalid shape passed to qr: %s' % input_shape)
      p = min(input_shape[-1], input_shape[-2])
      # Only rank-2 and rank-3 inputs are handled.  The result is a pair of
      # all-zero tensors returned as `(r, q)`, with `q` of shape [..., p, p]
      # and `r` shaped like the input.
      # NOTE(review): the real op presumably returns `(q, r)` - confirm the
      # swapped order here is intentional.
      if len(input_shape) == 2:
        q = array_ops.zeros((p, p), name=name)
        r = array_ops.zeros(input_shape, name=name)
        return (r, q)
      elif len(input_shape) == 3:
        n = input_shape[0]
        q = array_ops.zeros((n, p, p), name=name)
        r = array_ops.zeros(input_shape, name=name)
        return (r, q)
      else:
        raise ValueError('Invalid shape passed to qr: %s' % input_shape)

    gen_linalg_ops.qr = qr

    ops.name_scope = _name_scope
    # Variables are created through `get_variable` while the context is active.
    base_layer_utils.make_variable = variable_scope.get_variable
    logging.info('Overriding default placeholder.')
    return

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Restores every function replaced by `__enter__`."""
    array_ops.placeholder = self._default_placeholder
    ops.name_scope = self._default_name_scope
    base_layer_utils.make_variable = self._default_make_variable
    random_ops.random_normal = self._default_random_normal
    gen_linalg_ops.qr = self._default_qr
511
512
class SizedInfeed(
    collections.namedtuple('SizedInfeed',
                           ('sharded_infeed_tensors', 'infeed_ops'))):
  """The infeed ops instantiated for one concrete input shape.

  sharded_infeed_tensors: A data structure of Tensors representing the
    placeholder tensors that must be fed when using feed_dicts.

  infeed_ops: The set of ops that are run to drive infeed for a single step.
  """
524
525
class TPUInfeedInstance(object):
  """Encapsulates the logic for feeding data in a single step.

  See the comments on the `TPUInfeedManager` for a description of how infeed
  is managed.
  """

  @abc.abstractmethod
  def make_input_specs(self, input_tensors):
    """Constructs the input specs for this infeed instance.

    Args:
      input_tensors: The inputs to the model.

    Returns:
      A list of input specs, one per input. The Numpy implementation in this
      file returns `TensorSpec` objects.
    """

  def make_feed_dict(self, tpu_model_op):
    """Constructs a feed_dict for this instance, given the tpu_model_op.

    Args:
      tpu_model_op: A `TPUModelOp` representing the TPU Model for this
        instance's input spec.

    Returns:
      A dictionary to use as the feed_dict of a `session.run` call.
    """
556
557
@six.add_metaclass(abc.ABCMeta)
class TPUInfeedManager(object):
  """Abstract interface for feeding data into a TPU computation.

  Because there are multiple data sources (e.g. in-memory NumPy arrays,
  `tf.data.Dataset`s), the source-specific logic is hidden behind this single
  interface.  The intended flow is:

  (1) A `TPUFunction` is called with a set of inputs. Based on the inputs,
  `TPUFunction` retrieves the corresponding `TPUInfeedManager` (or constructs a
  new one if required).

  (2) The `TPUFunction` calls `make_infeed_instance` on the `TPUInfeedManager`
  which returns a `TPUInfeedInstance`.

  (3) The `TPUFunction` checks in the shape cache for a pre-compiled instance of
  the model based on the returned `input_specs` from `TPUInfeedInstance`.

  (4) [Optional.] If the model has not already been instantiated for the given
  input spec, the `TPUFunction` compiles the model for the input spec (using the
  `TPUInfeedManager`).

  (5) The `TPUInfeedInstance` constructs the session.run's feed_dict given the
  compiled model instance corresponding to its shape.
  """

  @abc.abstractmethod
  def make_infeed_instance(self, inputs):
    """Given a single step's input, construct a `TPUInfeedInstance`.

    Args:
      inputs: The inputs to a given step.

    Returns:
      A subclass of `TPUInfeedInstance`.
    """

  @abc.abstractmethod
  def build_infeed_from_input_specs(self, input_specs, execution_mode):
    """For a given input specification (size, type), construct the infeed ops.

    This is called only once for a given input specification and builds the
    graph ops. It does not have a pointer to the actual infeed data.

    Args:
      input_specs: TODO(saeta): Document me!
      execution_mode: TODO(saeta): Document me!

    Returns:
      A `SizedInfeed` instance.
    """
611
612
class TPUNumpyInfeedManager(TPUInfeedManager):
  """Infeed manager for in-memory Numpy inputs fed through placeholders."""

  class NumpyInfeedInstance(TPUInfeedInstance):
    """One step's worth of Numpy inputs, already split per shard."""

    def __init__(self, sharded_inputs):
      self._sharded_inputs = sharded_inputs

    def make_input_specs(self, input_tensors):
      # The specs drive generation of the infeed enqueue and dequeue ops.  The
      # shape comes from the input array while the dtype comes from the model:
      # a user may pass in a float64 for a float32 input, and for model
      # compatibility we still must generate a float32 infeed.
      #
      # Only the first shard is inspected; all replicas have the same type and
      # shape.
      first_shard = self._sharded_inputs[0]
      return [
          tensor_spec.TensorSpec(array.shape, tensor.dtype,
                                 _valid_name(tensor.name))
          for tensor, array in zip(input_tensors, first_shard)
      ]

    def make_feed_dict(self, tpu_model_op):
      # Pair each shard's placeholders with that shard's input arrays.
      return {
          placeholder: array
          for shard_placeholders, shard_arrays in zip(
              tpu_model_op.infeed_tensors, self._sharded_inputs)
          for placeholder, array in zip(shard_placeholders, shard_arrays)
      }

  def __init__(self, tpu_assignment):
    self._tpu_assignment = tpu_assignment

  def _split_tensors(self, inputs):
    """Split input data across shards.

    Each input is sliced along the batch axis into equally sized pieces, one
    per tower.

    Args:
      inputs: List of Numpy arrays to run on the TPU.

    Returns:
      List of lists containing the input to feed to each TPU shard.
    """
    num_towers = self._tpu_assignment.num_towers
    if num_towers == 1:
      return [inputs]

    batch_size = inputs[0].shape[0]
    assert batch_size % num_towers == 0, (
        'batch_size must be divisible by the number of TPU cores in use (%s '
        'vs %s)' % (batch_size, num_towers))
    shard_size = batch_size // num_towers
    return [
        [x[shard * shard_size:(shard + 1) * shard_size] for x in inputs]
        for shard in range(num_towers)
    ]

  def make_infeed_instance(self, inputs):
    return self.NumpyInfeedInstance(self._split_tensors(inputs))

  def build_infeed_from_input_specs(self, input_specs, execution_mode):
    """Builds per-shard placeholders and infeed enqueue ops.

    Args:
      input_specs: Specs providing the `dtype`, `shape` and `name` of every
        model input.
      execution_mode: Value formatted into the names of the enqueue ops.

    Returns:
      A `SizedInfeed` holding the enqueue ops and the placeholders per shard.
    """
    enqueue_ops = []
    placeholders_per_shard = []
    worker_cpu = '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name

    for shard_id in range(self._tpu_assignment.num_towers):
      with ops.device(worker_cpu):
        with ops.device('/device:TPU:%d' % shard_id):
          # Construct placeholders for each of the inputs.
          shard_placeholders = [
              array_ops.placeholder(
                  dtype=spec.dtype,
                  shape=spec.shape,
                  name='infeed-enqueue-%s-%d' % (spec.name, shard_id))
              for spec in input_specs
          ]
        placeholders_per_shard.append(shard_placeholders)

        enqueue_ops.append(
            tpu_ops.infeed_enqueue_tuple(
                shard_placeholders, [spec.shape for spec in input_specs],
                name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
                device_ordinal=shard_id))
    return SizedInfeed(
        infeed_ops=enqueue_ops, sharded_infeed_tensors=placeholders_per_shard)
704
705
706class TPUDatasetInfeedManager(TPUInfeedManager):
707  """Manages infeed for a `tf.data.Dataset` into a TPU computation.
708
709  """
710
711  class DatasetInfeedInstance(TPUInfeedInstance):
712    """An instance of the TPU infeed."""
713
714    def __init__(self, input_specs):
715      self._input_specs = input_specs
716
717    def make_input_specs(self, input_tensors):
718      # TODO(saeta): Do error checking here!
719      return self._input_specs
720
721    def make_feed_dict(self, tpu_model_op):
722      # TODO(saeta): Verify tpu_model_op is as expected!
723      return {}
724
725  # pylint: disable=redefined-outer-name
726  def __init__(self, dataset, tpu_assignment, mode):
727    """Constructs a TPUDatasetInfeedManager.
728
729    Args:
730      dataset: A `tf.data.Dataset` to infeed.
731      tpu_assignment: The `TPUAssignment` used to configure the
732        Keras TPU model.
733      mode: ModeKeys enum.
734    """
735    self._verify_dataset_shape(dataset)
736
737    self._dataset = dataset
738    self._tpu_assignment = tpu_assignment
739    dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset)
740    dummy_x_shape = dataset_output_shapes[0].as_list()
741    dummy_x_shape[0] *= tpu_assignment.num_towers
742    dummy_y_shape = dataset_output_shapes[1].as_list()
743    dummy_y_shape[0] *= tpu_assignment.num_towers
744    self._iterator = dataset_ops.make_initializable_iterator(dataset)
745    K.get_session().run(self._iterator.initializer)
746
747    self._get_next_ops = []
748    ctrl_deps = []
749    for i in range(tpu_assignment.num_towers):
750      with ops.control_dependencies(ctrl_deps):  # Ensure deterministic
751        # TODO(saeta): Ensure correct placement!
752        get_next_op = self._iterator.get_next()
753        self._get_next_ops.append(get_next_op)
754        ctrl_deps.extend(get_next_op)
755
756    # Use dummy numpy inputs for the rest of Keras' shape checking. We
757    # intercept them when building the model.
758    dataset_output_types = dataset_ops.get_legacy_output_types(dataset)
759    self._dummy_x = np.zeros(
760        dummy_x_shape, dtype=dataset_output_types[0].as_numpy_dtype)
761    self._dummy_y = np.zeros(
762        dummy_y_shape, dtype=dataset_output_types[1].as_numpy_dtype)
763
764    input_specs = []
765    iterator_output_shapes = dataset_ops.get_legacy_output_shapes(
766        self._iterator)
767    iterator_output_types = dataset_ops.get_legacy_output_types(self._iterator)
768    if isinstance(iterator_output_shapes, tuple):
769      assert isinstance(iterator_output_types, tuple)
770      assert len(iterator_output_shapes) == len(iterator_output_types)
771      for i in range(len(iterator_output_shapes)):
772        spec = tensor_spec.TensorSpec(iterator_output_shapes[i],
773                                      iterator_output_types[i])
774        input_specs.append(spec)
775    elif isinstance(iterator_output_shapes, tensor_shape.TensorShape):
776      spec = tensor_spec.TensorSpec(iterator_output_shapes,
777                                    iterator_output_types)
778      input_specs.append(spec)
779
780    # Pre-process the inputs and get_next_ops before caching.
781    input_specs, self._get_next_ops = (
782        _inject_tpu_inputs_for_dataset(
783            tpu_assignment, mode, input_specs, self._get_next_ops))
784    self._infeed_instance = self.DatasetInfeedInstance(input_specs)
785
786  def _verify_dataset_shape(self, dataset):
787    """Verifies a dataset is of an appropriate shape for TPUs."""
788    dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset)
789    dataset_output_classes = dataset_ops.get_legacy_output_classes(dataset)
790    if not isinstance(dataset, dataset_ops.DatasetV2):
791      raise ValueError('The function passed as the `x` parameter did not '
792                       'return a `tf.data.Dataset`.')
793    if not isinstance(dataset_output_classes, tuple):
794      raise ValueError('The dataset must return a tuple of tf.Tensors, '
795                       'instead it returns: %s' % dataset_output_classes)
796    if len(dataset_output_classes) != 2:
797      raise ValueError('The dataset must return a 2-element tuple, got '
798                       '%s output classes instead.' % (dataset_output_classes,))
799    for i, cls in enumerate(dataset_output_classes):
800      if cls != ops.Tensor:
801        raise ValueError('The dataset returned a non-Tensor type (%s) at '
802                         'index %d.' % (cls, i))
803    for i, shape in enumerate(dataset_output_shapes):
804      if not shape:
805        raise ValueError('The dataset returns a scalar tensor in '
806                         'tuple index %d. Did you forget to batch? '
807                         '(Output shapes: %s).' % (i, dataset_output_shapes))
808      for j, dim in enumerate(shape):
809        if dim.value is None:
810          if j == 0:
811            hint = (' Hint: did you use `ds.batch(BATCH_SIZE, '
812                    'drop_remainder=True)`?')
813          else:
814            hint = ''
815          raise ValueError(
816              'The Keras-TPU integration for `tf.data` '
817              'currently requires static shapes. The provided '
818              'dataset only has a partially defined shape. '
819              '(Dimension %d of output tensor %d is not statically known '
820              'for output shapes: %s.%s)' % (j, i, dataset_output_shapes, hint))
821
822  @property
823  def dummy_x(self):
824    return self._dummy_x
825
826  @property
827  def dummy_y(self):
828    return self._dummy_y
829
830  def make_infeed_instance(self, inputs):
831    # TODO(saeta): Verify inputs is as expected.
832    return self._infeed_instance
833
834  def build_infeed_from_input_specs(self, input_specs, execution_mode):
835    shard_infeed_tensors = self._get_next_ops
836    assert len(shard_infeed_tensors) == self._tpu_assignment.num_towers
837    infeed_ops = []
838    for shard_id in range(self._tpu_assignment.num_towers):
839      with ops.device(
840          '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
841        infeed_ops.append(
842            tpu_ops.infeed_enqueue_tuple(
843                shard_infeed_tensors[shard_id],
844                [spec.shape for spec in input_specs],
845                name='infeed-enqueue-%s-%d' % (execution_mode, shard_id),
846                device_ordinal=shard_id))
847    return SizedInfeed(
848        infeed_ops=infeed_ops, sharded_infeed_tensors=shard_infeed_tensors)
849
850
def _inject_tpu_inputs_for_dataset(tpu_assignment, mode,
                                   input_specs, get_next_ops):
  """Prepends a per-core id tensor, and its spec, to the dataset inputs.

  The core id is used during compilation to identify the current TPU core and
  enable concatenation operations across cores. Nothing is injected outside
  of TRAIN and EVAL modes.

  Args:
    tpu_assignment: Object exposing `num_towers`.
    mode: The `model_fn_lib.ModeKeys` value being built.
    input_specs: List of specs for the dataset outputs; not modified.
    get_next_ops: List with one sequence of input tensors per tower. In TRAIN
      and EVAL modes its entries are replaced in place.

  Returns:
    A tuple `(input_specs, get_next_ops)` with the core id entries at the head
    of each, or the arguments unchanged for other modes.
  """
  modes_with_core_id = [
      model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL
  ]
  if mode not in modes_with_core_id:
    return input_specs, get_next_ops

  # Dataset inputs operate on a per-core basis, so the leading dimension of
  # the first spec is already the per-core batch size.
  batch_per_core = input_specs[0].shape.as_list()[0]

  assert len(get_next_ops) == tpu_assignment.num_towers
  for tower in range(tpu_assignment.num_towers):
    # One copy of the core id per example in this core's batch.
    id_values = np.array([tower] * batch_per_core).astype('int32')
    core_id_tensor = constant_op.constant(
        id_values, dtype=dtypes.int32, name='cord_id_constant')
    get_next_ops[tower] = [core_id_tensor] + list(get_next_ops[tower])

  core_id_spec = tensor_spec.TensorSpec([batch_per_core], dtypes.int32)
  return [core_id_spec] + input_specs, get_next_ops
876
877
def _inject_tpu_inputs_for_infeed(tpu_assignment, mode,
                                  core_id_place_holder, input_tensors, inputs):
  """Prepends core id information to the set of NumPy inputs.

  The core id is used during compilation to identify the current TPU core and
  enable concatenation operations across cores. Nothing is injected outside
  of TRAIN and EVAL modes.

  Args:
    tpu_assignment: Object exposing `num_towers`.
    mode: The `model_fn_lib.ModeKeys` value being executed.
    core_id_place_holder: Tensor placed at the head of `input_tensors` to
      stand for the core id input.
    input_tensors: List of model input tensors; not modified.
    inputs: List of NumPy arrays matching `input_tensors`; not modified.

  Returns:
    A tuple `(input_tensors, inputs)` with the core id entries at the head of
    each, or the arguments unchanged for other modes.
  """
  if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
    return input_tensors, inputs

  # Puts a place holder in input spec.
  input_tensors = [core_id_place_holder] + input_tensors

  # Now fill the core id. For `num_cores` = 2, `batch_size` = 8, we fill the
  # core id inputs as [0, 0, 0, 0, 1, 1, 1, 1], so each core sees its core id
  # (duplicated).
  num_cores = tpu_assignment.num_towers
  per_core_batch_size = inputs[0].shape[0] // num_cores
  # Build the ids as int32 so they match the int32 core id tensor used on the
  # dataset path instead of relying on NumPy's platform default integer type.
  core_ids = np.arange(num_cores, dtype=np.int32).repeat(per_core_batch_size)
  inputs = [core_ids] + inputs
  return input_tensors, inputs
897
898
def _read_tpu_coreid_from_infeed(mode, infeed_tensors):
  """Splits the leading core id tensor off the dequeued infeed tensors.

  Args:
    mode: The `model_fn_lib.ModeKeys` value being executed.
    infeed_tensors: Sequence of tensors read from the infeed.

  Returns:
    A tuple `(core_id, remaining_tensors)`. Outside of TRAIN and EVAL modes no
    core id is present, so `(None, infeed_tensors)` is returned.

  Raises:
    RuntimeError: If, in TRAIN or EVAL mode, there is at most one tensor.
  """
  if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]:
    return None, infeed_tensors

  tensor_count = len(infeed_tensors)
  if tensor_count <= 1:
    message = (
        'The infeed tensors on TPU core has only {} tensors. '
        'This is not expected. Please report a bug.\nTensors: {}')
    raise RuntimeError(message.format(tensor_count, infeed_tensors))

  # The first tensor repeats the core id; its first element is the scalar id.
  core_id_tensor = infeed_tensors[0]
  return core_id_tensor[0], infeed_tensors[1:]
913
914
class TPUFunction(object):
  """K.function compatible interface for invoking a TPU compiled function.

  Recompilation is triggered on-demand for each set of new inputs shapes: the
  results are cached for future execution.  We expect most computations will
  be dominated by a standard batch-size, followed by a straggler batch for
  the end of training or evaluation.

  All `inputs` and `outputs` will be loaded via the infeed and outfeed queues
  instead of being injected as `feed_dict` items or fetches.
  """

  def __init__(self, model, execution_mode, tpu_assignment):
    """Initializes the function for one model and one execution mode.

    Args:
      model: The model to execute; its layers, optimizer and compile settings
        are read when the model is specialized.
      execution_mode: The `model_fn_lib.ModeKeys` value this function runs in.
      tpu_assignment: Object exposing `num_towers` and `worker_name`, used to
        place and replicate the generated ops.
    """
    self.model = model
    self.execution_mode = execution_mode
    self._tpu_assignment = tpu_assignment
    # Maps a tuple of input shapes to the `TPUModelOp` built for those shapes.
    self._compilation_cache = {}
    self._cloned_model = None
    self._cloned_optimizer = None
    # Create a placeholder for the TPU core ID. Cache the placeholder to avoid
    # modifying the graph for every batch.
    self._core_id_place_holder = array_ops.placeholder(
        dtype=dtypes.int32, shape=[1], name='core_id')

  def _specialize_model(self, input_specs, infeed_manager):
    """Specialize `self.model` (a Keras model) for the given input shapes.

    Args:
      input_specs: Specs (dtype and shape) of the tensors sent via infeed.
      infeed_manager: Manager used to build the CPU-side infeed ops.

    Returns:
      A `TPUModelOp` bundling the compile, execute, infeed and outfeed ops.
    """
    # Re-create our input and output layers inside our subgraph.  They will be
    # attached to the true computation when we clone our model in `tpu_fn`.
    K.set_learning_phase(self.execution_mode == model_fn_lib.ModeKeys.TRAIN)

    # functools.partial and callable objects are not supported by tpu.rewrite
    def _model_fn():
      """Compute fit/eval/predict for the TPU."""
      is_training = self.execution_mode == model_fn_lib.ModeKeys.TRAIN
      is_test = self.execution_mode == model_fn_lib.ModeKeys.EVAL
      is_predict = self.execution_mode == model_fn_lib.ModeKeys.PREDICT

      # During train/eval, we infeed our features as well as labels.
      if is_training or is_test:
        infeed_layers = self.model._input_layers + self.model._output_layers
      else:
        infeed_layers = self.model._input_layers

      # Generate our infeed operation to read features & labels.
      infeed_tensors = tpu_ops.infeed_dequeue_tuple(
          dtypes=[spec.dtype for spec in input_specs],
          shapes=[spec.shape for spec in input_specs],
          name='infeed-%s' % self.execution_mode)

      core_id, infeed_tensors = (
          _read_tpu_coreid_from_infeed(
              mode=self.execution_mode, infeed_tensors=infeed_tensors))

      assert len(infeed_tensors) == len(infeed_layers), (
          'Infeed inputs did not match model: %s vs %s' % (infeed_layers,
                                                           infeed_tensors))

      tpu_targets = []
      tpu_input_map = {}

      # Sort infeed outputs into inputs and labels for calling our Keras model.
      for tensor, layer in zip(infeed_tensors, infeed_layers):
        if layer in self.model._input_layers:
          tpu_input_map[layer.name] = tensor
        if layer in self.model._output_layers:
          tpu_targets.append(tensor)

      # Clone our CPU model, running within the TPU device context.
      #
      # We use the id of the original model as a key to avoid weight collisions
      # (if a user re-runs the same model multiple times, in e.g. Colab).
      with TPURewriteContext(tpu_input_map):
        with variable_scope.variable_scope('tpu_%s' % id(self.model)):
          with keras_tpu_variables.replicated_scope(
              self._tpu_assignment.num_towers):
            if not self._cloned_optimizer:
              self._cloned_optimizer = _clone_optimizer(
                  self.model.cpu_optimizer,
                  worker_name=self._tpu_assignment.worker_name)

            self._cloned_model = models.clone_model(self.model)

            # When running on more than one core, concatenate outputs at the end
            # of processing. In backprop stage, the gradients will be
            # calculated according to the local inputs as gradient of
            # cross-replica-concat being zero for any outputs other than those
            # from the local core so the loss calculation is identical.
            num_towers = self.model._tpu_assignment.num_towers
            if num_towers > 1 and (is_training or is_test):
              new_outputs = [
                  _cross_replica_concat(
                      o, core_id, num_towers,
                      name='model output ({})'.format(o.name))
                  for o in self._cloned_model.outputs
              ]
              # Recast all low precision outputs back to float32 since we only
              # casted the inputs to bfloat16 and not targets. This is done so
              # that we can preserve precision when calculating the loss value.
              if new_outputs and new_outputs[0].dtype == dtypes.bfloat16:
                new_outputs = [
                    math_ops.cast(o, dtypes.float32) for o in new_outputs]
              self._cloned_model.outputs = new_outputs
              tpu_targets = [
                  _cross_replica_concat(
                      tensor,
                      core_id,
                      num_towers,
                      name='model target ({})'.format(tensor.name))
                  for tensor in tpu_targets
              ]

          if is_training or is_test:
            with variable_scope.variable_scope(
                'metrics', reuse=variable_scope.AUTO_REUSE):
              self._cloned_model.compile(
                  optimizer=_replicated_optimizer(self._cloned_optimizer),
                  loss=self.model.loss,
                  loss_weights=self.model.loss_weights,
                  metrics=metrics_module.clone_metrics(
                      self.model._compile_metrics),
                  weighted_metrics=metrics_module.clone_metrics(
                      self.model._compile_weighted_metrics),
                  target_tensors=tpu_targets,
              )

      # Compute our outfeed depending on the execution mode
      if is_training:
        if not isinstance(self._cloned_optimizer, keras_optimizers.TFOptimizer):
          # For Keras optimizer, we try to place the variable weights on the TPU
          # device. Keras creates optimizer variables (e.g. momentum values for
          # the Momentum optimizer) when _make_train_function is invoked.
          with keras_tpu_variables.replicated_variable_for_optimizer(
              self._tpu_assignment.num_towers):
            self._cloned_model._make_train_function()
        else:
          self._cloned_model._make_train_function()

        self._outfeed_spec = [
            tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
            for tensor in self._cloned_model.train_function.outputs
        ]
        return [
            self._cloned_model.train_function.updates_op,
            tpu_ops.outfeed_enqueue_tuple(
                self._cloned_model.train_function.outputs,
                name='outfeed-enqueue-train')
        ]
      elif is_test:
        self._cloned_model._make_test_function()
        self._outfeed_spec = [
            tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
            for tensor in self._cloned_model.test_function.outputs
        ]
        return [
            tpu_ops.outfeed_enqueue_tuple(
                self._cloned_model.test_function.outputs,
                name='outfeed-enqueue-test')
        ]
      elif is_predict:
        self._cloned_model._make_predict_function()
        self._outfeed_spec = [
            tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
            for tensor in self._cloned_model.predict_function.outputs
        ]
        return [
            tpu_ops.outfeed_enqueue_tuple(
                self._cloned_model.predict_function.outputs,
                name='outfeed-enqueue-predict',
            )
        ]
      else:
        # Raised explicitly (rather than `assert False`) so the check is not
        # stripped when Python runs with optimizations enabled.
        raise AssertionError(
            'Unexpected execution mode: %s' % self.execution_mode)

    # Capture outfeed metadata computed during the rewrite.
    self._outfeed_spec = None

    # Generate out TPU operations using `tpu.split_compile_and_replicate`.
    # `compile_op` can be used to test the TPU model compiles before execution.
    # `execute op` replicates `_model_fn` `num_replicas` times, with each shard
    # running on a different logical core.
    compile_op, execute_op = tpu.split_compile_and_replicate(
        _model_fn, inputs=[[] for _ in range(self._tpu_assignment.num_towers)])

    # Generate CPU side operations to enqueue features/labels and dequeue
    # outputs from the model call.
    sized_infeed = infeed_manager.build_infeed_from_input_specs(
        input_specs, self.execution_mode)
    # Build output ops.
    outfeed_op = []
    for shard_id in range(self._tpu_assignment.num_towers):
      with ops.device(
          '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
        outfeed_op.extend(
            tpu_ops.outfeed_dequeue_tuple(
                dtypes=[spec.dtype for spec in self._outfeed_spec],
                shapes=[spec.shape for spec in self._outfeed_spec],
                name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id),
                device_ordinal=shard_id))

    return TPUModelOp(
        compile_op,
        execute_op,
        infeed_tensors=sized_infeed.sharded_infeed_tensors,
        infeed_op=sized_infeed.infeed_ops,
        outfeed_op=outfeed_op)

  def _test_model_compiles(self, tpu_model_ops):
    """Verifies that the given TPUModelOp can be compiled via XLA.

    Args:
      tpu_model_ops: The `TPUModelOp` whose `compile_op` is run.

    Raises:
      RuntimeError: If the compilation result carries an error message.
    """
    logging.info('Started compiling')
    start_time = time.time()

    result = K.get_session().run(tpu_model_ops.compile_op)
    proto = tpu_compilation_result.CompilationResultProto()
    proto.ParseFromString(result)
    if proto.status_error_message:
      raise RuntimeError('Compilation failed: {}'.format(
          proto.status_error_message))

    end_time = time.time()
    logging.info('Finished compiling. Time elapsed: %s secs',
                 end_time - start_time)

  def _lookup_infeed_manager(self, inputs):
    """Return an existing manager, or construct a new InfeedManager for inputs.

    _lookup_infeed_manager will return an existing InfeedManager if one has been
    previously assigned for this model and input. If not, it will construct a
    new TPUNumpyInfeedManager.

    Args:
      inputs: A NumPy input to the model.

    Returns:
      A `TPUInfeedManager` object to manage infeeds for this input, or `None`
      when `inputs` is `None`.
    """
    if inputs is None:
      return None

    # Managers are matched by identity of the first input array.
    for x, mgr in self.model._numpy_to_infeed_manager_list:
      if inputs[0] is x:
        return mgr

    return TPUNumpyInfeedManager(self.model._tpu_assignment)

  def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager):
    """Looks up the corresponding `TPUModelOp` for a given `input_specs`.

    It instantiates a new copy of the model for each unique input shape.

    Args:
      input_specs: The specification of the inputs to train on.
      infeed_manager: The infeed manager responsible for feeding in data.

    Returns:
      A `TPUModelOp` instance that can be used to execute a step of the model,
      or `None` when either argument is `None`.

    Raises:
      RuntimeError: If the newly specialized model fails to compile.
    """
    if input_specs is None or infeed_manager is None:
      # Note: this condition is possible during the prologue or epilogue of the
      # pipelined loop.
      return None

    # XLA requires every operation in the graph has a fixed shape.  To
    # handle varying batch sizes we recompile a new sub-graph for each
    # unique input shape.
    shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs])
    if shape_key not in self._compilation_cache:
      logging.info(
          'New input shapes; (re-)compiling: mode=%s '
          '(# of cores %d), %s', self.execution_mode,
          self._tpu_assignment.num_towers, input_specs)
      new_tpu_model_ops = self._specialize_model(input_specs,
                                                 infeed_manager)
      # Validate before caching: if compilation fails the exception propagates
      # and nothing is stored, so a later call with the same shapes re-checks
      # compilation instead of reusing ops that are known not to compile.
      self._test_model_compiles(new_tpu_model_ops)
      self._compilation_cache[shape_key] = new_tpu_model_ops

    return self._compilation_cache[shape_key]

  def _construct_input_tensors_and_inputs(self, inputs):
    """Returns input tensors and numpy array inputs corresponding to `inputs`.

    Args:
      inputs: NumPy inputs.

    Returns:
      A tuple of `input_tensors`, and `inputs`; `(None, None)` when `inputs`
      is `None`.
    """
    if inputs is None:
      # Note: this condition is possible during the prologue or epilogue of the
      # pipelined loop.
      return None, None

    if isinstance(inputs[-1], int):
      # Remove the learning_phase flag at the end. We currently hard code the
      # learning_phase in TPUFunction.
      inputs = inputs[:-1]

    if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or
        self.execution_mode == model_fn_lib.ModeKeys.EVAL):
      # Strip sample weight from inputs.
      input_tensors = self.model._feed_inputs + self.model._feed_targets
    else:
      input_tensors = self.model._feed_inputs

    # Keep only as many arrays as there are input tensors.
    inputs = inputs[:len(input_tensors)]
    input_tensors, inputs = (
        _inject_tpu_inputs_for_infeed(
            self._tpu_assignment, self.execution_mode,
            self._core_id_place_holder, input_tensors, inputs))
    return input_tensors, inputs

  def _process_outputs(self, outfeed_outputs):
    """Processes the outputs of a model function execution.

    Args:
      outfeed_outputs: The sharded outputs of the TPU computation.

    Returns:
      The aggregated outputs of the TPU computation to be used in the rest of
      the model execution.
    """
    # TODO(xiejw): Decide how to reduce outputs, or discard all but first.
    if self.execution_mode == model_fn_lib.ModeKeys.PREDICT:
      # Regroup the flat, replica-major list by output index and concatenate
      # each output across replicas.
      outputs = [[] for _ in range(len(self._outfeed_spec))]
      outputs_per_replica = len(self._outfeed_spec)

      for i in range(self._tpu_assignment.num_towers):
        output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) *
                                       outputs_per_replica]
        for j in range(outputs_per_replica):
          outputs[j].append(output_group[j])

      return [np.concatenate(group) for group in outputs]
    else:
      # Keep only the first replica's share of the outputs.
      return outfeed_outputs[:len(outfeed_outputs) //
                             self._tpu_assignment.num_towers]

  def __call__(self, inputs):
    """__call__ executes the function on the computational hardware.

    It handles executing infeed, and preprocessing in addition to executing the
    model on the TPU hardware.

    Note: `__call__` has a sibling method `pipeline_run` which performs the same
    operations, but with software pipelining.

    Args:
      inputs: The inputs to use to train.

    Returns:
      The output of the computation for the given mode it is executed in.

    Raises:
      RuntimeError: If there is an inappropriate use of the function.
    """
    assert isinstance(inputs, list)

    infeed_manager = self._lookup_infeed_manager(inputs)
    input_tensors, inputs = self._construct_input_tensors_and_inputs(inputs)
    infeed_instance = infeed_manager.make_infeed_instance(inputs)
    del inputs  # To avoid accidental usage.
    input_specs = infeed_instance.make_input_specs(input_tensors)
    tpu_model_ops = self._tpu_model_ops_for_input_specs(input_specs,
                                                        infeed_manager)
    infeed_dict = infeed_instance.make_feed_dict(tpu_model_ops)

    # Initialize our TPU weights on the first compile.
    self.model._initialize_weights(self._cloned_model)

    _, _, outfeed_outputs = K.get_session().run([
        tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
        tpu_model_ops.outfeed_op
    ], infeed_dict)
    return self._process_outputs(outfeed_outputs)

  def pipeline_run(self, cur_step_inputs, next_step_inputs):
    """pipeline_run executes the function on the computational hardware.

    pipeline_run performs the same computation as __call__, however it runs the
    infeed in a software pipelined fashion compared to the on-device execution.

    Note: it is the responsibility of the caller to call `pipeline_run` in the
    following sequence:
      - Once with `cur_step_inputs=None` and `next_step_inputs=list(...)`
      - `n` times with `cur_step_inputs` and `next_step_inputs` as `list`s
      - Once with `cur_step_inputs=list(...)` and `next_step_inputs=None`
    Additionally, it is the responsibility of the caller to pass
    `next_step_inputs` as `cur_step_inputs` on the next invocation of
    `pipeline_run`.

    Args:
      cur_step_inputs: The current step's inputs.
      next_step_inputs: The next step's inputs.

    Returns:
      The output of the computation for the given mode it is executed in, or
      `None` for the prologue call that only enqueues the next step's inputs.

    Raises:
      RuntimeError: If there is an inappropriate use of the function.
    """
    # Software pipelined case.
    next_step_infeed_manager = self._lookup_infeed_manager(next_step_inputs)
    cur_step_infeed_manager = self._lookup_infeed_manager(cur_step_inputs)

    if (next_step_infeed_manager is not None and
        cur_step_infeed_manager is not None):
      assert type(next_step_infeed_manager) is type(cur_step_infeed_manager)

    next_input_tensors, next_step_inputs = (
        self._construct_input_tensors_and_inputs(next_step_inputs))
    cur_input_tensors, cur_step_inputs = (
        self._construct_input_tensors_and_inputs(cur_step_inputs))

    cur_infeed_instance = None
    if cur_step_infeed_manager:
      cur_infeed_instance = cur_step_infeed_manager.make_infeed_instance(
          cur_step_inputs)
    next_infeed_instance = None
    if next_step_infeed_manager:
      next_infeed_instance = next_step_infeed_manager.make_infeed_instance(
          next_step_inputs)

    del cur_step_inputs  # Avoid accidental re-use.
    del next_step_inputs  # Avoid accidental re-use.

    cur_tpu_model_ops = None
    next_tpu_model_ops = None
    infeed_dict = None

    if cur_infeed_instance and cur_input_tensors and cur_step_infeed_manager:
      cur_input_specs = cur_infeed_instance.make_input_specs(cur_input_tensors)
      cur_tpu_model_ops = self._tpu_model_ops_for_input_specs(
          cur_input_specs, cur_step_infeed_manager)

    if (next_infeed_instance and next_input_tensors and
        next_step_infeed_manager):
      next_input_specs = next_infeed_instance.make_input_specs(
          next_input_tensors)
      next_tpu_model_ops = self._tpu_model_ops_for_input_specs(
          next_input_specs, next_step_infeed_manager)
      infeed_dict = next_infeed_instance.make_feed_dict(next_tpu_model_ops)

    # Initialize our TPU weights on the first compile.
    self.model._initialize_weights(self._cloned_model)

    # Steady state: enqueue the next step's inputs while executing the
    # current step.
    if next_tpu_model_ops and cur_tpu_model_ops:
      _, _, outfeed_outputs = K.get_session().run([
          next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op,
          cur_tpu_model_ops.outfeed_op
      ], infeed_dict)
      return self._process_outputs(outfeed_outputs)

    # Epilogue: nothing left to enqueue, only execute the current step.
    if cur_tpu_model_ops:
      _, outfeed_outputs = K.get_session().run(
          [cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op])
      return self._process_outputs(outfeed_outputs)

    # Prologue: only enqueue the first step's inputs.
    if next_tpu_model_ops:
      K.get_session().run(next_tpu_model_ops.infeed_op, infeed_dict)
      return None
    raise RuntimeError('Internal error: both current & next tpu_model_ops '
                       'were None')
1376
1377
1378class KerasTPUModel(models.Model):
1379  """TPU compatible Keras model wrapper."""
1380
1381  def __init__(self, cpu_model, strategy):
1382    super(models.Model, self).__init__(  # pylint: disable=bad-super-call
1383        inputs=cpu_model.inputs,
1384        outputs=cpu_model.outputs,
1385        name=cpu_model.name,
1386    )
1387    if tf2.enabled():
1388      raise RuntimeError(
1389          'Keras support is now deprecated in support of TPU Strategy. '
1390          'Please follow the distribution strategy guide on tensorflow.org '
1391          'to migrate to the 2.0 supported version.')
1392    else:
1393      logging.warning(
1394          'Keras support is now deprecated in support of TPU Strategy. '
1395          'Please follow the distribution strategy guide on tensorflow.org '
1396          'to migrate to the 2.0 supported version.')
1397    # Create a mapping from numpy arrays to infeed managers.
1398    # Note: uses a list of tuples instead of a map because numpy arrays are
1399    # not hashable.
1400    self._numpy_to_infeed_manager_list = []
1401
1402    # Add distribution specific arguments since we don't call the Model init.
1403    self._distribution_strategy = None
1404    self._compile_distribution = None
1405
1406    self.predict_function = None
1407    self.test_function = None
1408    self.train_function = None
1409    self._stateful_metric_functions = []
1410
1411    cluster_resolver = strategy._tpu_cluster_resolver
1412    self._tpu_name_or_address = cluster_resolver.get_master()
1413    self._cpu_model = cpu_model
1414    self._tpu_assignment = strategy._make_assignment_for_model(cpu_model)
1415    self._tpu_model = None
1416    self._tpu_weights_initialized = False
1417
1418    # If the input CPU model has already been compiled, compile our TPU model
1419    # immediately.
1420    if self._cpu_model.optimizer:
1421      self.compile(
1422          self._cpu_model.optimizer,
1423          self._cpu_model.loss,
1424          self._cpu_model._compile_metrics,
1425          self._cpu_model.loss_weights,
1426          self._cpu_model.sample_weight_mode,
1427          self._cpu_model._compile_weighted_metrics,
1428          self._cpu_model.target_tensors,
1429      )
1430
1431    # This flag must be disabled upon model mutation, such as changing the model
1432    # layers or recompiling the model to use a different optimizer. New function
1433    # definitions are generated whenever this flag is disabled, ensuring that
1434    # internal graph functions are always using the current model structure.
1435    #
1436    # Requires declaration here because this constructor skips the
1437    # Model constructor.
1438    self._built_graph_functions = False
1439
1440  def get_config(self):
1441    return {
1442        'cpu_model': self._cpu_model,
1443        'tpu_name_or_address': self._tpu_name_or_address,
1444        'tpu_assignment': self._tpu_assignment,
1445    }
1446
1447  def compile(self,
1448              optimizer,
1449              loss=None,
1450              metrics=None,
1451              loss_weights=None,
1452              sample_weight_mode=None,
1453              weighted_metrics=None,
1454              target_tensors=None,
1455              **kwargs):
1456    if sample_weight_mode:
1457      raise ValueError('sample_weight_mode not supported for TPU execution.')
1458    if weighted_metrics:
1459      raise ValueError('weighted_metrics not supported for TPU execution.')
1460    if target_tensors:
1461      raise ValueError('target_tensors is not supported for TPU execution.')
1462
1463    self._cpu_model.compile(
1464        _clone_optimizer(optimizer), loss,
1465        metrics_module.clone_metrics(metrics), loss_weights, sample_weight_mode,
1466        metrics_module.clone_metrics(weighted_metrics), target_tensors,
1467        **kwargs)
1468
1469    super(KerasTPUModel, self).compile(optimizer, loss, metrics, loss_weights,
1470                                       sample_weight_mode, weighted_metrics,
1471                                       target_tensors, **kwargs)
1472
1473  def fit(self,
1474          x=None,
1475          y=None,
1476          batch_size=None,
1477          epochs=1,
1478          verbose=1,
1479          callbacks=None,
1480          validation_split=0.,
1481          validation_data=None,
1482          shuffle=True,
1483          class_weight=None,
1484          sample_weight=None,
1485          initial_epoch=0,
1486          steps_per_epoch=None,
1487          validation_steps=None,
1488          **kwargs):
1489    if context.executing_eagerly():
1490      raise EnvironmentError('KerasTPUModel currently does not support eager '
1491                             'mode.')
1492
1493    with _tpu_session_context():
1494      assert not self._numpy_to_infeed_manager_list  # Ensure empty.
1495
1496      infeed_managers = []  # Managers to clean up at the end of the fit call.
1497      if isinstance(x, dataset_ops.DatasetV2):
1498        # TODO(b/111413240): Support taking a tf.data.Dataset directly.
1499        raise ValueError(
1500            'Taking a Dataset directly is not yet supported. Please '
1501            'wrap your dataset construction code in a function and '
1502            'pass that to fit instead. For examples, see: '
1503            'https://github.com/tensorflow/tpu/tree/master/models/experimental'
1504            '/keras')
1505      if callable(x):
1506        with ops.device(
1507            '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
1508          dataset = x()
1509          if steps_per_epoch is None:
1510            raise ValueError('When using tf.data as input to a model, you '
1511                             'should specify the steps_per_epoch argument.')
1512          if y is not None:
1513            raise ValueError('When using tf.data as input to a model, y must '
1514                             'be None')
1515          infeed_manager = TPUDatasetInfeedManager(
1516              dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN)
1517          # Use dummy numpy inputs for the rest of Keras' shape checking. We
1518          # intercept them when building the model.
1519          x = infeed_manager.dummy_x
1520          y = infeed_manager.dummy_y
1521          infeed_managers.append((x, infeed_manager))
1522
1523      if isinstance(validation_data, dataset_ops.DatasetV2):
1524        # TODO(b/111413240): Support taking a tf.data.Dataset directly.
1525        raise ValueError(
1526            'Taking a Dataset directly is not yet supported. Please '
1527            'wrap your dataset construction code in a function and '
1528            'pass that to fit instead. For examples, see: '
1529            'https://github.com/tensorflow/tpu/tree/master/models/experimental'
1530            '/keras')
1531      if callable(validation_data):
1532        dataset = validation_data()
1533        if validation_steps is None:
1534          raise ValueError('When using tf.data as validation for a model, you '
1535                           'should specify the validation_steps argument.')
1536        infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
1537                                                 model_fn_lib.ModeKeys.EVAL)
1538        # Use dummy numpy inputs for the rest of Keras' shape checking. We
1539        # intercept them when building the model.
1540        val_x = infeed_manager.dummy_x
1541        val_y = infeed_manager.dummy_y
1542        infeed_managers.append((val_x, infeed_manager))
1543        validation_data = (val_x, val_y)
1544
1545      self._numpy_to_infeed_manager_list = infeed_managers
1546      try:
1547        pipeline = kwargs.get('_pipeline', True)
1548        if '_pipeline' in kwargs:
1549          kwargs.pop('_pipeline')
1550        if not pipeline:
1551          logging.info('Running non-pipelined training loop (`_pipeline=%s`).',
1552                       pipeline)
1553          return super(KerasTPUModel, self).fit(
1554              x, y, batch_size, epochs, verbose, callbacks, validation_split,
1555              validation_data, shuffle, class_weight, sample_weight,
1556              initial_epoch, steps_per_epoch, validation_steps, **kwargs)
1557        return self._pipeline_fit(x, y, batch_size, epochs, verbose, callbacks,
1558                                  validation_split, validation_data, shuffle,
1559                                  class_weight, sample_weight, initial_epoch,
1560                                  steps_per_epoch, validation_steps, **kwargs)
1561      finally:
1562        self._numpy_to_infeed_manager_list = []
1563
1564  def evaluate(self,
1565               x=None,
1566               y=None,
1567               batch_size=None,
1568               verbose=1,
1569               sample_weight=None,
1570               steps=None):
1571    original_numpy_to_infeed_manager_list = []
1572    if self._numpy_to_infeed_manager_list:
1573      # evaluate call may be executed as callbacks during the training. In this
1574      # case, _numpy_to_infeed_manager_list is not empty, so save it for
1575      # recovery at the end of evaluate call.
1576      original_numpy_to_infeed_manager_list = self._numpy_to_infeed_manager_list
1577      self._numpy_to_infeed_manager_list = []
1578
1579    with _tpu_session_context():
1580      # Managers to clean up at the end of the evaluate call.
1581      infeed_managers = []
1582      if isinstance(x, dataset_ops.DatasetV2):
1583        # TODO(b/111413240): Support taking a tf.data.Dataset directly.
1584        raise ValueError(
1585            'Taking a Dataset directly is not yet supported. Please '
1586            'wrap your dataset construction code in a function and '
1587            'pass that to fit instead. For examples, see: '
1588            'https://github.com/tensorflow/tpu/tree/master/models/experimental'
1589            '/keras')
1590      if callable(x):
1591        dataset = x()
1592        if steps is None:
1593          raise ValueError('When using tf.data as input to a model, you '
1594                           'should specify the steps argument.')
1595        if y is not None:
1596          raise ValueError('When using tf.data as input to a model, y must be '
1597                           'None')
1598        infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
1599                                                 model_fn_lib.ModeKeys.EVAL)
1600        # Use dummy numpy inputs for the rest of Keras' shape checking. We
1601        # intercept them when building the model.
1602        x = infeed_manager.dummy_x
1603        y = infeed_manager.dummy_y
1604        infeed_managers.append((x, infeed_manager))
1605
1606      self._numpy_to_infeed_manager_list = infeed_managers
1607      try:
1608        return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose,
1609                                                   sample_weight, steps)
1610      finally:
1611        self._numpy_to_infeed_manager_list = (
1612            original_numpy_to_infeed_manager_list)
1613
1614  def _pipeline_fit(self, x, y, batch_size, epochs, verbose, callbacks,
1615                    validation_split, validation_data, shuffle, class_weight,
1616                    sample_weight, initial_epoch, steps_per_epoch,
1617                    validation_steps, **kwargs):
1618    # Similar to super.fit(...), but modified to support software pipelining.
1619
1620    # Backwards compatibility
1621    if batch_size is None and steps_per_epoch is None:
1622      batch_size = 32
1623    # Legacy support
1624    if 'nb_epoch' in kwargs:
1625      logging.warning('The `nb_epoch` argument in `fit` has been renamed '
1626                      '`epochs`.')
1627      epochs = kwargs.pop('nb_epoch')
1628    if kwargs:
1629      raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
1630
1631    # Validate and standardize user data
1632    x, y, sample_weights = self._standardize_user_data(
1633        x,
1634        y,
1635        sample_weight=sample_weight,
1636        class_weight=class_weight,
1637        batch_size=batch_size,
1638        check_steps=True,
1639        steps_name='steps_per_epoch',
1640        steps=steps_per_epoch,
1641        validation_split=validation_split)
1642
1643    # Prepare validation data
1644    val_x, val_y, val_sample_weights = self._prepare_validation_data(
1645        validation_data, validation_split, validation_steps, x, y,
1646        sample_weights, batch_size)
1647    return self._pipeline_fit_loop(
1648        x,
1649        y,
1650        sample_weights=sample_weights,
1651        batch_size=batch_size,
1652        epochs=epochs,
1653        verbose=verbose,
1654        callbacks=callbacks,
1655        val_inputs=val_x,
1656        val_targets=val_y,
1657        val_sample_weights=val_sample_weights,
1658        shuffle=shuffle,
1659        initial_epoch=initial_epoch,
1660        steps_per_epoch=steps_per_epoch,
1661        validation_steps=validation_steps)
1662
1663  def _pipeline_fit_loop(self,
1664                         inputs,
1665                         targets,
1666                         sample_weights,
1667                         batch_size,
1668                         epochs,
1669                         verbose,
1670                         callbacks,
1671                         val_inputs,
1672                         val_targets,
1673                         val_sample_weights,
1674                         shuffle,
1675                         initial_epoch,
1676                         steps_per_epoch,
1677                         validation_steps):
1678    self._make_train_function()
1679    sample_weights = sample_weights or []
1680    val_sample_weights = val_sample_weights or []
1681    if not isinstance(K.learning_phase(), int):
1682      ins = inputs + targets + sample_weights + [1]
1683    else:
1684      ins = inputs + targets + sample_weights
1685
1686    do_validation = False
1687    if val_inputs:
1688      do_validation = True
1689      if (steps_per_epoch is None and verbose and inputs and
1690          hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')):
1691        print('Train on %d samples, validate on %d samples' %
1692              (inputs[0].shape[0], val_inputs[0].shape[0]))
1693
1694    if validation_steps:
1695      do_validation = True
1696      if steps_per_epoch is None:
1697        raise ValueError('Can only use `validation_steps` when doing step-wise '
1698                         'training, i.e. `steps_per_epoch` must be set.')
1699
1700    num_training_samples = training_utils.check_num_samples(
1701        ins, batch_size, steps_per_epoch, 'steps_per_epoch')
1702    count_mode = 'steps' if steps_per_epoch else 'samples'
1703    callbacks = cbks.configure_callbacks(
1704        callbacks,
1705        self,
1706        do_validation=do_validation,
1707        batch_size=batch_size,
1708        epochs=epochs,
1709        steps_per_epoch=steps_per_epoch,
1710        samples=num_training_samples,
1711        verbose=verbose,
1712        count_mode=count_mode)
1713
1714    if num_training_samples is not None:
1715      index_array = np.arange(num_training_samples)
1716
1717    # To prevent a slowdown, we find beforehand the arrays that need conversion.
1718    feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights
1719    indices_for_conversion_to_dense = []
1720    for i in range(len(feed)):
1721      if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]):
1722        indices_for_conversion_to_dense.append(i)
1723
1724    callbacks.on_train_begin()
1725    for epoch in range(initial_epoch, epochs):
1726      # Reset stateful metrics
1727      for m in self.metrics:
1728        m.reset_states()
1729      # Update callbacks
1730      callbacks.on_epoch_begin(epoch)
1731      epoch_logs = {}
1732      if steps_per_epoch is not None:
1733        # Step-wise fit loop.
1734        self._pipeline_fit_loop_step_wise(
1735            ins=ins,
1736            callbacks=callbacks,
1737            steps_per_epoch=steps_per_epoch,
1738            epochs=epochs,
1739            do_validation=do_validation,
1740            val_inputs=val_inputs,
1741            val_targets=val_targets,
1742            val_sample_weights=val_sample_weights,
1743            validation_steps=validation_steps,
1744            epoch_logs=epoch_logs)
1745      else:
1746        # Sample-wise fit loop.
1747        self._pipeline_fit_loop_sample_wise(
1748            ins=ins,
1749            callbacks=callbacks,
1750            index_array=index_array,
1751            shuffle=shuffle,
1752            batch_size=batch_size,
1753            num_training_samples=num_training_samples,
1754            indices_for_conversion_to_dense=indices_for_conversion_to_dense,
1755            do_validation=do_validation,
1756            val_inputs=val_inputs,
1757            val_targets=val_targets,
1758            val_sample_weights=val_sample_weights,
1759            validation_steps=validation_steps,
1760            epoch_logs=epoch_logs)
1761
1762      callbacks.on_epoch_end(epoch, epoch_logs)
1763      if callbacks.model.stop_training:
1764        break
1765    callbacks.on_train_end()
1766    return self.history
1767
1768  def _pipeline_fit_loop_sample_wise(self,
1769                                     ins,
1770                                     callbacks,
1771                                     index_array,
1772                                     shuffle,
1773                                     batch_size,
1774                                     num_training_samples,
1775                                     indices_for_conversion_to_dense,
1776                                     do_validation,
1777                                     val_inputs,
1778                                     val_targets,
1779                                     val_sample_weights,
1780                                     validation_steps,
1781                                     epoch_logs):
1782    f = self.train_function
1783    if shuffle == 'batch':
1784      index_array = training_utils.batch_shuffle(index_array, batch_size)
1785    elif shuffle:
1786      np.random.shuffle(index_array)
1787    batches = make_batches(num_training_samples, batch_size)
1788
1789    ins_last_batch = None
1790    last_batch_logs = None
1791    batch_index = 0
1792
1793    for batch_index, (batch_start, batch_end) in enumerate(batches):
1794      batch_ids = index_array[batch_start:batch_end]
1795      try:
1796        if isinstance(ins[-1], int):
1797          # Do not slice the training phase flag.
1798          ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
1799        else:
1800          ins_batch = slice_arrays(ins, batch_ids)
1801      except TypeError:
1802        raise TypeError('TypeError while preparing batch. If using HDF5 '
1803                        'input data, pass shuffle="batch".')
1804
1805      # Pipeline batch logs
1806      next_batch_logs = {}
1807      next_batch_logs['batch'] = batch_index
1808      next_batch_logs['size'] = len(batch_ids)
1809      if batch_index > 0:
1810        # Callbacks operate one step behind in software pipeline.
1811        callbacks.on_batch_begin(batch_index - 1, last_batch_logs)
1812      for i in indices_for_conversion_to_dense:
1813        ins_batch[i] = ins_batch[i].toarray()
1814
1815      outs = f.pipeline_run(
1816          cur_step_inputs=ins_last_batch, next_step_inputs=ins_batch)
1817      ins_last_batch = ins_batch
1818
1819      if batch_index == 0:
1820        assert outs is None
1821      else:
1822        if not isinstance(outs, list):
1823          outs = [outs]
1824        for l, o in zip(self.metrics_names, outs):
1825          last_batch_logs[l] = o  # pylint: disable=unsupported-assignment-operation
1826        callbacks.on_batch_end(batch_index - 1, last_batch_logs)
1827        if callbacks.model.stop_training:
1828          return
1829      last_batch_logs = next_batch_logs
1830
1831    # Final batch
1832    callbacks.on_batch_begin(batch_index, last_batch_logs)
1833    outs = f.pipeline_run(cur_step_inputs=ins_last_batch, next_step_inputs=None)
1834    if not isinstance(outs, list):
1835      outs = [outs]
1836    for l, o in zip(self.metrics_names, outs):
1837      last_batch_logs[l] = o
1838    callbacks.on_batch_end(batch_index, last_batch_logs)
1839    if callbacks.model.stop_training:
1840      return
1841
1842    if do_validation:
1843      val_outs = training_arrays.test_loop(
1844          self,
1845          val_inputs,
1846          val_targets,
1847          sample_weights=val_sample_weights,
1848          batch_size=batch_size,
1849          steps=validation_steps,
1850          verbose=0)
1851      if not isinstance(val_outs, list):
1852        val_outs = [val_outs]
1853      # Same labels assumed.
1854      for l, o in zip(self.metrics_names, val_outs):
1855        epoch_logs['val_' + l] = o
1856
1857  def _pipeline_fit_loop_step_wise(self,
1858                                   ins,
1859                                   callbacks,
1860                                   steps_per_epoch,
1861                                   epochs,
1862                                   do_validation,
1863                                   val_inputs,
1864                                   val_targets,
1865                                   val_sample_weights,
1866                                   validation_steps,
1867                                   epoch_logs):
1868    f = self.train_function
1869
1870    # Loop prologue
1871    try:
1872      outs = f.pipeline_run(cur_step_inputs=None, next_step_inputs=ins)
1873      assert outs is None  # Function shouldn't return anything!
1874    except errors.OutOfRangeError:
1875      logging.warning('Your dataset iterator ran out of data on the first step '
1876                      'of the epoch, preventing further training. Check to '
1877                      'make sure your paths are correct and you have '
1878                      'permissions to read the files. Skipping validation')
1879
1880    for step_index in range(steps_per_epoch):
1881      batch_logs = {'batch': step_index, 'size': 1}
1882      callbacks.on_batch_begin(step_index, batch_logs)
1883      try:
1884        if step_index < steps_per_epoch - 1:
1885          next_step_inputs = ins
1886        else:
1887          next_step_inputs = None
1888        outs = f.pipeline_run(
1889            cur_step_inputs=ins, next_step_inputs=next_step_inputs)
1890      except errors.OutOfRangeError:
1891        logging.warning('Your dataset iterator ran out of data; '
1892                        'interrupting training. Make sure that your '
1893                        'dataset can generate at least `steps_per_batch * '
1894                        'epochs` batches (in this case, %d batches). You '
1895                        'may need to use the repeat() function when '
1896                        'building your dataset.' % steps_per_epoch * epochs)
1897        break
1898
1899      if not isinstance(outs, list):
1900        outs = [outs]
1901      for l, o in zip(self.metrics_names, outs):
1902        batch_logs[l] = o
1903
1904      callbacks.on_batch_end(step_index, batch_logs)
1905      if callbacks.model.stop_training:
1906        break
1907
1908    if do_validation:
1909      val_outs = training_arrays.test_loop(
1910          self,
1911          val_inputs,
1912          val_targets,
1913          sample_weights=val_sample_weights,
1914          steps=validation_steps,
1915          verbose=0)
1916      if not isinstance(val_outs, list):
1917        val_outs = [val_outs]
1918      # Same labels assumed.
1919      for l, o in zip(self.metrics_names, val_outs):
1920        epoch_logs['val_' + l] = o
1921
1922  def _prepare_validation_data(self, validation_data, validation_split,
1923                               validation_steps, x, y, sample_weights,
1924                               batch_size):
1925    """Prepares the validation dataset.
1926
1927    Args:
1928      validation_data: The validation data (if provided)
1929      validation_split: The validation split (if provided)
1930      validation_steps: The validation steps (if provided)
1931      x: The main training data x (if provided)
1932      y: The main training data y (if provided)
1933      sample_weights: The sample weights (if provided)
1934      batch_size: The training batch size (if provided)
1935
1936    Returns:
1937      A 3-tuple of (val_x, val_y, val_sample_weights).
1938
1939    Raises:
1940      ValueError: If the provided arguments are not compatible with
1941        `KerasTPUModel`.
1942    """
1943    # Note: this is similar to a section of $tf/python/keras/engine/training.py
1944    # It differns in that tf.data objects are not allowed to be passed directly.
1945    # Additionally, it handles validating shapes & types appropriately for use
1946    # in TPUs.
1947    if validation_data:
1948      if (isinstance(validation_data, iterator_ops.Iterator) or
1949          isinstance(validation_data, iterator_ops.EagerIterator) or
1950          isinstance(validation_data, dataset_ops.DatasetV2)):
1951        raise ValueError('KerasTPUModel cannot handle a Dataset or Iterator '
1952                         'for validation_data. Please instead pass a function '
1953                         'that returns a `tf.data.Dataset`.')
1954      if len(validation_data) == 2:
1955        val_x, val_y = validation_data  # pylint: disable=unpacking-non-sequence
1956        val_sample_weight = None
1957      elif len(validation_data) == 3:
1958        val_x, val_y, val_sample_weight = validation_data  # pylint: disable=unpacking-non-sequence
1959      else:
1960        raise ValueError('When passing a `validation_data` argument, it must '
1961                         'contain either 2 items (x_val, y_val), or 3 items '
1962                         '(x_val, y_val, val_sample_weights). However we '
1963                         'received `validation_data=%s`' % validation_data)
1964      val_x, val_y, val_sample_weights = self._standardize_user_data(
1965          val_x,
1966          val_y,
1967          sample_weight=val_sample_weight,
1968          batch_size=batch_size,
1969          steps=validation_steps)
1970    elif validation_split and 0. < validation_split < 1.:
1971      if training_utils.has_symbolic_tensors(x):
1972        raise ValueError('If your data is in the form of symbolic tensors, you '
1973                         'cannot use `validation_split`.')
1974      if hasattr(x[0], 'shape'):
1975        split_at = int(x[0].shape[0] * (1. - validation_split))
1976      else:
1977        split_at = int(len(x[0]) * (1. - validation_split))
1978
1979      x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at))
1980      y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
1981      sample_weights, val_sample_weights = (
1982          slice_arrays(sample_weights, 0, split_at),
1983          slice_arrays(sample_weights, split_at)
1984      )
1985    elif validation_steps:
1986      val_x = []
1987      val_y = []
1988      val_sample_weights = []
1989    else:
1990      val_x = None
1991      val_y = None
1992      val_sample_weights = None
1993
1994    return val_x, val_y, val_sample_weights
1995
1996  def predict(self,
1997              x,
1998              batch_size=None,
1999              verbose=0,
2000              steps=None,
2001              max_queue_size=10,
2002              workers=1,
2003              use_multiprocessing=False):
2004    with _tpu_session_context():
2005      return super(KerasTPUModel, self).predict(
2006          x,
2007          batch_size=batch_size,
2008          verbose=verbose,
2009          steps=steps,
2010          max_queue_size=max_queue_size,
2011          workers=workers,
2012          use_multiprocessing=use_multiprocessing)
2013
2014  @property
2015  def optimizer(self):
2016    if self._tpu_model:
2017      return self._tpu_model.optimizer
2018    return self._cpu_model.optimizer
2019
2020  @optimizer.setter
2021  def optimizer(self, optimizer):
2022    self._optimizer = optimizer
2023
2024  @property
2025  def metrics(self):
2026    if self._tpu_model:
2027      return self._tpu_model.metrics
2028    return self._stateful_metric_functions
2029
2030  @metrics.setter
2031  def metrics(self, metrics):
2032    self._stateful_metric_functions = metrics
2033
2034  def _make_train_function(self):
2035    if not self.train_function:
2036      self.train_function = TPUFunction(
2037          self,
2038          model_fn_lib.ModeKeys.TRAIN,
2039          tpu_assignment=self._tpu_assignment)
2040
2041    return self.train_function
2042
2043  def _make_test_function(self):
2044    if not self.test_function:
2045      self.test_function = TPUFunction(
2046          self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment)
2047    return self.test_function
2048
2049  def _make_predict_function(self):
2050    if not self.predict_function:
2051      self.predict_function = TPUFunction(
2052          self,
2053          model_fn_lib.ModeKeys.PREDICT,
2054          tpu_assignment=self._tpu_assignment)
2055    return self.predict_function
2056
2057  def _initialize_weights(self, cloned_model):
2058    """Initialize TPU weights.
2059
2060    This is called on the first compile of the TPU model (first call to
2061    fit/predict/evaluate).
2062
2063    Args:
2064      cloned_model: `keras.Model`, TPU model to initialize.
2065    """
2066    if self._tpu_weights_initialized:
2067      return
2068
2069    self._tpu_model = cloned_model
2070    self._tpu_weights_initialized = True
2071
2072    weights = self._cpu_model.get_weights()
2073
2074    if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer):
2075      cpu_optimizer_config = {}
2076    else:
2077      cpu_optimizer_config = self.cpu_optimizer.get_config()
2078
2079    logging.info('Setting weights on TPU model.')
2080    cloned_model.set_weights(weights)
2081    if self._tpu_model.optimizer is None:
2082      # tpu_model may not be compiled, e.g., loading weights and then predict.
2083      return
2084    for k, v in six.iteritems(cpu_optimizer_config):
2085      if k == 'name':
2086        continue
2087      opt_var = getattr(self._tpu_model.optimizer, k)
2088      if isinstance(opt_var, variables.Variable):
2089        logging.info('CPU -> TPU %s: %s {%s}', k, v, K.get_value(opt_var))
2090        K.get_session().run(opt_var.assign(v))
2091      else:
2092        logging.warning('Cannot update non-variable config: %s', k)
2093
2094  @property
2095  def cpu_optimizer(self):
2096    return self._cpu_model.optimizer
2097
2098  def sync_to_cpu(self):
2099    """Copy weights from the CPU, returning a synchronized CPU model."""
2100    if not self._tpu_weights_initialized:
2101      return self._cpu_model
2102
2103    logging.info('Copying TPU weights to the CPU')
2104    tpu_weights = self._tpu_model.get_weights()
2105
2106    # TFOptimizers have no configurable options
2107    if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer):
2108      tpu_optimizer_config = {}
2109    else:
2110      tpu_optimizer_config = self._tpu_model.optimizer.get_config()
2111
2112    self._cpu_model.set_weights(tpu_weights)
2113    for k, v in six.iteritems(tpu_optimizer_config):
2114      logging.info('TPU -> CPU %s: %s', k, v)
2115      if k == 'name':
2116        continue
2117      opt_var = getattr(self.cpu_optimizer, k)
2118      if isinstance(opt_var, variables.Variable):
2119        K.get_session().run(opt_var.assign(v))
2120      else:
2121        logging.warning('Cannot update non-variable config: %s', k)
2122
2123    return self._cpu_model
2124
2125  def get_weights(self):
2126    return self.sync_to_cpu().get_weights()
2127
2128  def save_weights(self, *args, **kw):
2129    return self.sync_to_cpu().save_weights(*args, **kw)
2130
2131  def save(self, *args, **kw):
2132    return self.sync_to_cpu().save(*args, **kw)
2133
2134  def set_weights(self, weights):
2135    # We may not have a TPU model available if we haven't run fit/predict, so
2136    # we can't directly set the TPU weights here.
2137    # Instead, reset CPU model weights and force TPU re-initialization at the
2138    # next call.
2139    self._cpu_model.set_weights(weights)
2140    self._tpu_weights_initialized = False
2141
2142  def load_weights(self, filepath, by_name=False):
2143    self._cpu_model.load_weights(filepath, by_name)
2144    self._tpu_weights_initialized = False
2145
2146
2147# pylint: disable=bad-continuation
def _validate_shapes(model):
  """Checks that every layer of `model` has fully defined non-batch dims.

  Args:
    model: Object exposing `layers`; each layer exposes `input_shape` and
      `output_shape` as a single tuple or a list of shapes.

  Raises:
    ValueError: If a dimension other than the first one is None in any input
      or output shape of any layer.
  """
  error_template = (
      '\n'
      'Layer %(layer)s has a variable shape in a non-batch dimension.  '
      'TPU models must\n'
      'have constant shapes for all operations.\n'
      '\n'
      'You may have to specify `input_length` for RNN/TimeDistributed '
      'layers.\n'
      '\n'
      'Layer: %(layer)s\n'
      'Input shape: %(input_shape)s\n'
      'Output shape: %(output_shape)s\n'
      '  ')

  def _as_shape_list(shape_or_shapes):
    """Wraps a single shape tuple in a list; passes other values through."""
    if isinstance(shape_or_shapes, tuple):
      return [shape_or_shapes]
    return shape_or_shapes

  for layer in model.layers:
    layer_shapes = (
        _as_shape_list(layer.input_shape) + _as_shape_list(layer.output_shape))
    for shape in layer_shapes:
      # The leading (batch) dimension is allowed to be unknown.
      if any(dim is None for dim in shape[1:]):
        raise ValueError(error_template % {
            'layer': layer,
            'input_shape': layer.input_shape,
            'output_shape': layer.output_shape
        })
2179
2180
2181# pylint: enable=bad-continuation
2182
2183
@deprecated(
    '2019-02-20', 'Switch to tf.contrib.distribute.TPUStrategy. '
    'https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy'
)
def tpu_model(model, strategy=None):
  """Wraps a clone of `model` in a `KerasTPUModel`.

  The model is cloned after the TPU session has been set up. If `model` has an
  optimizer, the clone is compiled with a cloned optimizer and the original
  loss/metric settings, and the original weights are copied onto it.

  Usage:
  ```
  a = Input(shape=(32,))
  b = Dense(32)(a)
  model = Model(inputs=a, outputs=b)

  # If `num_cores_per_host` is greater than one, batch parallelism will be used
  # to run on multiple TPU cores.
  strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
  model = keras_support.tpu_model(model, strategy)
  model.compile(
      optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0),
      ...)
  ```

  Args:
    model: A `tf.keras.Model` instance.
    strategy: `TPUDistributionStrategy`.  The strategy to use for replicating
      model across multiple TPU cores. A default-constructed strategy is used
      when None.

  Returns:
    A new `KerasTPUModel` instance.

  Raises:
    TypeError: If `strategy` is not a `TPUDistributionStrategy`.
    ValueError: If a layer of `model` has an undefined non-batch dimension.
  """
  _validate_shapes(model)
  # TODO(xiejw): Validate TPU model. TPUModel only?
  # TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset?
  # TODO(xiejw): Adds reduction option.

  if strategy is None:
    strategy = TPUDistributionStrategy()
  elif not isinstance(strategy, TPUDistributionStrategy):
    raise TypeError(
        '`strategy` must have type `tf.contrib.tpu.TPUDistributionStrategy`. '
        'Got: {}'.format(type(strategy)))

  # If the model has already been initialized, grab the optimizer configuration
  # and model weights before entering the TPU session.
  source_optimizer = model.optimizer
  optimizer_config = None
  model_weights = None
  if source_optimizer:
    has_keras_config = (
        isinstance(source_optimizer, keras_optimizers.Optimizer) and
        not isinstance(source_optimizer, keras_optimizers.TFOptimizer))
    if has_keras_config:
      optimizer_config = source_optimizer.get_config()
    model_weights = model.get_weights()

  setup_tpu_session(strategy._tpu_cluster_resolver)

  # Force initialization of the CPU model in the TPU session.
  cpu_model = models.clone_model(model)
  if source_optimizer:
    cpu_model.compile(
        _clone_optimizer(source_optimizer, optimizer_config),
        model.loss,
        metrics_module.clone_metrics(model._compile_metrics),
        model.loss_weights,
        model.sample_weight_mode,
        metrics_module.clone_metrics(model._compile_weighted_metrics),
    )

  if model_weights:
    cpu_model.set_weights(model_weights)
    cpu_model.reset_states()

  return KerasTPUModel(cpu_model=cpu_model, strategy=strategy)
2260