1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""*Experimental* support for running Keras models on the TPU. 16 17To use, wrap your model with the `keras_support.tpu_model` function. 18 19Example usage: 20 21``` 22image = tf.keras.layers.Input(shape=(28, 28, 3), name='image') 23c1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3))( image) 24flattened = tf.keras.layers.Flatten()(c1) 25logits = tf.keras.layers.Dense(10, activation='softmax')(flattened) 26model = tf.keras.Model(inputs=[image], outputs=[logits]) 27 28resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=tpu_name) 29strategy = keras_support.TPUDistributionStrategy(resolver) 30model = keras_support.tpu_model(model, strategy=strategy) 31 32# Only TF optimizers are currently supported. 33model.compile(optimizer=tf.train.AdamOptimizer(), ...) 34 35# `images` and `labels` should be Numpy arrays. Support for tensor input 36# (e.g. datasets) is planned. 37model.fit(images, labels) 38``` 39""" 40 41# pylint: disable=protected-access 42 43from __future__ import absolute_import 44from __future__ import division 45from __future__ import print_function 46 47import abc 48import collections 49import contextlib 50import re 51import sys 52import time 53 54import numpy as np 55import six 56 57from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib 58from tensorflow.contrib.tpu.python.ops import tpu_ops 59from tensorflow.contrib.tpu.python.tpu import keras_tpu_variables 60from tensorflow.contrib.tpu.python.tpu import tpu 61from tensorflow.contrib.tpu.python.tpu import tpu_function 62from tensorflow.contrib.tpu.python.tpu import tpu_optimizer 63from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib 64from tensorflow.core.protobuf import config_pb2 65from tensorflow.core.protobuf.tpu import compilation_result_pb2 as tpu_compilation_result 66from tensorflow.python import tf2 67from tensorflow.python.client import session as tf_session 68from tensorflow.python.data.ops import dataset_ops 69from tensorflow.python.data.ops import iterator_ops 70from tensorflow.python.eager import context 71from tensorflow.python.estimator import model_fn as model_fn_lib 72from tensorflow.python.framework import constant_op 73from tensorflow.python.framework import dtypes 74from tensorflow.python.framework import errors 75from tensorflow.python.framework import ops 76from tensorflow.python.framework import tensor_shape 77from tensorflow.python.framework import tensor_spec 78from tensorflow.python.keras import backend as K 79from tensorflow.python.keras import callbacks as cbks 80from tensorflow.python.keras import metrics as metrics_module 81from tensorflow.python.keras import models 82from tensorflow.python.keras import optimizers as keras_optimizers 83from tensorflow.python.keras.engine import base_layer 84from tensorflow.python.keras.engine import base_layer_utils 85from tensorflow.python.keras.engine import training_arrays 86from tensorflow.python.keras.engine import training_utils 87from tensorflow.python.keras.layers import embeddings 88from tensorflow.python.keras.utils.generic_utils import make_batches 89from tensorflow.python.keras.utils.generic_utils import slice_arrays 90from tensorflow.python.ops import array_ops 91from tensorflow.python.ops import gen_linalg_ops 92from tensorflow.python.ops import math_ops 93from tensorflow.python.ops import random_ops 94from tensorflow.python.ops import variable_scope 95from tensorflow.python.ops import variables 96from tensorflow.python.platform import tf_logging as logging 97from tensorflow.python.util.deprecation import deprecated 98 99 100# TODO(b/114775106): temporary shim to optionally initialize the TPU 101# This increases the odds our session is initialized, but shouldn't be needed. 102_TEST_REWRITE_OP = None 103 104 105def _maybe_initialize_tpu(session): 106 """Initialize the TPU if it has not already been initialized.""" 107 global _TEST_REWRITE_OP 108 try: 109 # Try to use cached version to avoid another ground of graph optimization. 110 test_rewrite_op = _TEST_REWRITE_OP 111 if (test_rewrite_op is None or 112 test_rewrite_op[0].graph != ops.get_default_graph()): 113 114 def test_op(): 115 return constant_op.constant(1) + constant_op.constant(1) 116 117 test_rewrite_op = tpu.rewrite(test_op) 118 _TEST_REWRITE_OP = test_rewrite_op 119 120 session.run(test_rewrite_op) 121 except errors.FailedPreconditionError as _: 122 session.run(tpu.initialize_system()) 123 124 125@contextlib.contextmanager 126def _tpu_session_context(): 127 """Initialize the TPU and cleans cache entries for bad sessions.""" 128 try: 129 _maybe_initialize_tpu(K.get_session()) 130 yield 131 except (errors.FailedPreconditionError, errors.AbortedError) as e: 132 K.clear_session() 133 raise Exception(""" 134An error occurred connecting or initializing your TPU. 135 136The session has been reset. re-run keras_to_tpu_model to create a new session. 137""" + str(e)) 138 139 140def setup_tpu_session(cluster_resolver): 141 """Construct or return a `tf.Session` connected to the given cluster.""" 142 master = cluster_resolver.master() 143 144 # Use the existing session if we're already connected to this TPU 145 # N.B K.get_session() is a non-trivial operation, and may fail if the remote 146 # session has been reset. 147 try: 148 default_session = K.get_session() 149 if (default_session._target == master and 150 getattr(default_session, '_tpu_initialized', None)): 151 return 152 except errors.AbortedError as _: 153 # We lost the remote session and need to re-initialize. 154 logging.warning('Lost remote session: creating a new session.') 155 156 cluster_spec = cluster_resolver.cluster_spec() 157 config = config_pb2.ConfigProto(isolate_session_state=True) 158 if cluster_spec: 159 config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) 160 161 tpu_session = tf_session.Session(target=master, config=config) 162 tpu_session.run(tpu.initialize_system()) 163 tpu_session._tpu_initialized = True 164 165 # N.B. We have to call `K.set_session()` AND set our session as the 166 # TF default. `K.get_session()` surprisingly does not return the value 167 # supplied by K.set_session otherwise. 168 K.set_session(tpu_session) 169 170 171try: 172 from scipy.sparse import issparse # pylint: disable=g-import-not-at-top 173except ImportError: 174 issparse = None 175 176 177def get_tpu_system_metadata(tpu_cluster_resolver): 178 """Retrieves TPU system metadata given a TPUClusterResolver.""" 179 master = tpu_cluster_resolver.master() 180 181 # pylint: disable=protected-access 182 cluster_spec = tpu_cluster_resolver.cluster_spec() 183 cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None 184 tpu_system_metadata = ( 185 tpu_system_metadata_lib._query_tpu_system_metadata( 186 master, cluster_def=cluster_def, query_topology=False)) 187 188 return tpu_system_metadata 189 190 191class TPUDistributionStrategy(object): 192 """The strategy to run Keras model on TPU.""" 193 194 def __init__(self, tpu_cluster_resolver=None, using_single_core=False): 195 """Construct a TPUDistributionStrategy. 196 197 Args: 198 tpu_cluster_resolver: Any instance of `TPUClusterResolver`. If None, will 199 create one with '' as master address. 200 using_single_core: Bool. This is the debugging option, which might be 201 removed in future once the model replication functionality is mature 202 enough. If `False` (default behavior), the system automatically finds 203 the best configuration, in terms of number of TPU cores, for the model 204 replication, typically using all available TPU cores. If overwrites as 205 `True`, force the model replication using single core, i.e., no 206 replication. 207 Raises: 208 Exception: No TPU Found on the given worker. 209 """ 210 if tf2.enabled(): 211 raise RuntimeError( 212 'Keras support is now deprecated in support of TPU Strategy. ' 213 'Please follow the distribution strategy guide on tensorflow.org ' 214 'to migrate to the 2.0 supported version.') 215 else: 216 logging.warning( 217 'Keras support is now deprecated in support of TPU Strategy. ' 218 'Please follow the distribution strategy guide on tensorflow.org ' 219 'to migrate to the 2.0 supported version.') 220 if tpu_cluster_resolver is None: 221 tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('') 222 223 metadata = get_tpu_system_metadata(tpu_cluster_resolver) 224 self._tpu_metadata = metadata 225 self._tpu_cluster_resolver = tpu_cluster_resolver 226 self._num_cores = 1 if using_single_core else metadata.num_cores 227 228 # Walk device list to identify TPU worker for enqueue/dequeue operations. 229 worker_re = re.compile('/job:([^/]+)') 230 for device in metadata.devices: 231 if 'TPU:0' in device.name: 232 self._worker_name = worker_re.search(device.name).group(1) 233 return 234 raise Exception('No TPU found on given worker.') 235 236 def _make_assignment_for_model(self, cpu_model): 237 """Makes a `TPUAssignment` for the passed in `cpu_model`.""" 238 num_cores = self._num_cores 239 if num_cores > 1 and cpu_model.stateful: 240 logging.warning( 241 'Model replication does not currently support stateful models. ' 242 'Degrading to a single core.') 243 num_cores = 1 244 245 return TPUAssignment(worker_name=self._worker_name, num_cores=num_cores) 246 247 248class TPUAssignment(object): 249 """This is object holding TPU resources assignment for the concrete model. 250 251 `TPUDistributionStrategy` is responsible to create the instance of 252 `TPUAssignment`, so, it can dynamically adjust the `num_cores` to use based on 253 model and input batch sizes. 254 """ 255 256 def __init__(self, worker_name, num_cores): 257 self._worker_name = worker_name 258 self._num_cores = num_cores 259 260 @property 261 def worker_name(self): 262 return self._worker_name 263 264 @property 265 def num_towers(self): 266 # TODO(xiejw): Support automatically assign num_cores based on inputs. 267 return self._num_cores 268 269 270class TPUEmbedding(embeddings.Embedding): 271 """TPU compatible embedding layer. 272 273 The default Keras layer is not TPU compatible. This layer is a drop-in 274 replacement: it has the same behavior and will work on CPU and GPU devices. 275 """ 276 277 def build(self, input_shape): 278 if input_shape[0] is None: 279 raise ValueError( 280 'TPUEmbeddings must have a fixed input_length or input shape.') 281 return super(TPUEmbedding, self).build(input_shape) 282 283 def call(self, inputs): 284 if K.dtype(inputs) != 'int32': 285 inputs = math_ops.cast(inputs, 'int32') 286 287 inputs = array_ops.one_hot(inputs, self.input_dim) 288 return math_ops.tensordot(inputs, self.embeddings, 1) 289 290 291def _cross_replica_concat(tensor, core_id, num_cores, name): 292 """Concatenate `tensor` across cores. 293 294 Args: 295 tensor: The tensor to be concatenated. Must be [int32 and float32]. 296 core_id: Tensor indicating the current TPU core. 297 num_cores: Python int. The total number of TPU cores in the system. 298 name: The string name to print for debugging. 299 300 Returns: 301 The same concatenated Tensor on each core. 302 """ 303 304 input_dtype = tensor.dtype 305 if input_dtype not in [dtypes.bfloat16, dtypes.float32, dtypes.int32]: 306 raise TypeError('For model replication, only (bfloat16, float32 and int32) ' 307 'is supported for model outputs and targets. Got {} for ' 308 '{}.'.format(input_dtype, name)) 309 310 batch_size = tensor.shape[0] 311 mask = math_ops.cast( 312 math_ops.equal(np.arange(num_cores, dtype=np.int32), core_id), 313 dtypes.float32) 314 mask = array_ops.reshape(mask, [num_cores] + [1] * tensor.shape.ndims) 315 result = mask * math_ops.cast(tensor, dtypes.float32) 316 local_tensor_with_holes = array_ops.reshape(result, 317 [-1] + result.shape.as_list()[2:]) 318 concat_tensor = tpu_ops.cross_replica_sum(local_tensor_with_holes) 319 concat_tensor.set_shape((num_cores * batch_size,) + tuple(tensor.shape[1:])) 320 321 if concat_tensor != input_dtype: 322 concat_tensor = math_ops.cast(concat_tensor, input_dtype) 323 return concat_tensor 324 325 326class KerasCrossShardOptimizer(keras_optimizers.Optimizer): 327 """An optimizer that averages gradients across TPU shards.""" 328 329 def __init__(self, opt, name='KerasCrossShardOptimizer'): 330 """Construct a new cross-shard optimizer. 331 332 Args: 333 opt: An existing `Optimizer` to encapsulate. 334 name: Optional name prefix for the operations created when applying 335 gradients. Defaults to "KerasCrossShardOptimizer". 336 337 Raises: 338 ValueError: If reduction is not a valid cross-shard reduction. 339 """ 340 super(KerasCrossShardOptimizer, self).__init__() 341 self._name = name 342 self._opt = opt 343 logging.info('KerasCrossShard: %s %s', self._opt, self._opt.weights) 344 345 def get_updates(self, loss, params): 346 self._opt.get_gradients = self.get_gradients 347 return self._opt.get_updates(loss, params) 348 349 def get_gradients(self, loss, params): 350 num_shards = tpu_function.get_tpu_context().number_of_shards 351 grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params) 352 return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads] 353 354 def get_weights(self): 355 return self._opt.get_weights() 356 357 def get_config(self): 358 return self._opt.get_config() 359 360 # Defer remaining operations to the underlying optimizer 361 def __getattr__(self, key): 362 return getattr(self._opt, key) 363 364 365class TPUModelOp( 366 collections.namedtuple('TPUModelOp', [ 367 'compile_op', 'execute_op', 'infeed_tensors', 'infeed_op', 'outfeed_op' 368 ])): 369 pass 370 371 372def _valid_name(tensor_name): 373 """Return a valid tensor name (strips '/', ':', etc).""" 374 return re.sub('[^a-zA-Z0-9_-]+', '', tensor_name) 375 376 377def _replicated_optimizer(opt): 378 """Wrap the optimizer `opt` with CrossShardOptimizer if applicable.""" 379 # Always wrap `opt` with CrossShardOptimizer, even if we are running on a 380 # single core. This ensures Keras properly tracks and initializes optimizer 381 # variables. 382 if isinstance(opt, keras_optimizers.TFOptimizer): 383 return tpu_optimizer.CrossShardOptimizer(opt.optimizer) 384 else: 385 return KerasCrossShardOptimizer(opt) 386 387 388def _clone_optimizer(optimizer, config=None, worker_name=None): 389 """Returns a cloned optimizer with the provided optimizer.config or config.""" 390 if not isinstance(optimizer, keras_optimizers.Optimizer): 391 # In the first call to tpu_model(model), Keras may not have wrapped the TF 392 # optimizer in the TFOptimizer helper, e.g., the given model isn't compiled 393 # or optimizer isn't set, and later generated tpu_model compiles with a TF 394 # optimizer. 395 return optimizer 396 397 if isinstance(optimizer, keras_optimizers.TFOptimizer): 398 return keras_optimizers.TFOptimizer(optimizer.optimizer) 399 400 if config is None: 401 config = optimizer.get_config() 402 logging.info('Cloning %s %s', optimizer.__class__.__name__, config) 403 with ops.device( 404 '%s/device:CPU:0' % ('/job:%s' % worker_name if worker_name else '')): 405 # Explicitly put optimizer parameter variables on TPU worker. 406 return optimizer.__class__.from_config(config) 407 408 409class TPURewriteContext(object): 410 """Prepare the environment for a Keras model during `tpu.rewrite`. 411 412 This overrides the default placeholder behaviour to instead refer to a preset 413 input mapping. Placeholders are unsupported in TPU compiled code, and must 414 be replaced with explicit inputs or values from the infeed queue. 415 416 Instead of explicitly threading inputs all the way through the Keras codebase, 417 we override the behavior of the placeholder while compiling and inject the 418 Tensors from the infeed in place of the placeholder. 419 420 Similarly, as we compile a new sub-graph for each unique shape and execution 421 mode, we need to override the behavior of an embedded `name_scope` call in 422 the base Keras layer code. This allows us to re-use the same weights across 423 many compiles and share a single session/graph. 424 """ 425 426 def __init__(self, input_map): 427 self._input_map = input_map 428 self._default_placeholder = None 429 self._default_name_scope = None 430 431 def __enter__(self): 432 433 def _placeholder(dtype, shape=None, name=None): # pylint: disable=unused-argument 434 logging.info('Remapping placeholder for %s', name) 435 if name in self._input_map: 436 return self._input_map[name] 437 else: 438 logging.info('Default: %s', name) 439 return self._default_placeholder(dtype, shape, name) 440 441 def _name_scope(name, default_name=None, values=None): 442 caller_frame = sys._getframe().f_back 443 caller_obj = caller_frame.f_locals.get('self') 444 if (caller_obj is not None and 445 isinstance(caller_obj, base_layer.Layer) and name is not None): 446 return variable_scope.variable_scope( 447 name, default_name, values, reuse=variable_scope.AUTO_REUSE) 448 449 return self._default_name_scope(name, default_name, values) 450 451 self._default_placeholder = array_ops.placeholder 452 self._default_name_scope = ops.name_scope 453 self._default_make_variable = base_layer_utils.make_variable 454 self._default_random_normal = random_ops.random_normal 455 self._default_qr = gen_linalg_ops.qr 456 457 array_ops.placeholder = _placeholder 458 459 # Replace random_ops.random_normal with a dummy function because 460 # `random_normal` isn't yet implemented on the TPU. Because these 461 # initialized values are overwritten by the CPU values, this is okay. 462 def random_normal(shape, 463 mean=0.0, 464 stddev=1.0, 465 dtype=dtypes.float32, 466 seed=None, 467 name=None): 468 del mean 469 del stddev 470 del seed 471 return array_ops.zeros(shape, dtype=dtype, name=name) 472 473 random_ops.random_normal = random_normal 474 475 # Replace gen_linalg_ops.qr because QR decomposition is not yet implemented. 476 # TODO(saeta): Remove qr override once we confirm the qr implementation is 477 # ok. 478 # pylint: disable=redefined-builtin 479 def qr(input, full_matrices=False, name=None): 480 """Dummy implementation of qr decomposition.""" 481 del full_matrices # TODO(saeta): Properly handle the full matrix case. 482 input_shape = input.shape 483 if len(input_shape) < 2: 484 raise ValueError('Invalid shape passed to qr: %s' % input_shape) 485 p = min(input_shape[-1], input_shape[-2]) 486 if len(input_shape) == 2: 487 q = array_ops.zeros((p, p), name=name) 488 r = array_ops.zeros(input_shape, name=name) 489 return (r, q) 490 elif len(input_shape) == 3: 491 n = input_shape[0] 492 q = array_ops.zeros((n, p, p), name=name) 493 r = array_ops.zeros(input_shape, name=name) 494 return (r, q) 495 else: 496 raise ValueError('Invalid shape passed to qr: %s' % input_shape) 497 498 gen_linalg_ops.qr = qr 499 500 ops.name_scope = _name_scope 501 base_layer_utils.make_variable = variable_scope.get_variable 502 logging.info('Overriding default placeholder.') 503 return 504 505 def __exit__(self, exc_type, exc_val, exc_tb): 506 array_ops.placeholder = self._default_placeholder 507 ops.name_scope = self._default_name_scope 508 base_layer_utils.make_variable = self._default_make_variable 509 random_ops.random_normal = self._default_random_normal 510 gen_linalg_ops.qr = self._default_qr 511 512 513class SizedInfeed( 514 collections.namedtuple('SizedInfeed', 515 ['sharded_infeed_tensors', 'infeed_ops'])): 516 """Represents an instantiation of the infeed ops for a concrete input shape. 517 518 sharded_infeed_tensors: A data structure of Tensors used to represent the 519 placeholder tensors that must be fed when using feed_dicts. 520 521 infeed_ops: the set of ops that will be run to drive infeed for a single step. 522 """ 523 pass 524 525 526class TPUInfeedInstance(object): 527 """TPUInfeedInstance represents the logic to manage feeding in a single step. 528 529 See the comments on the `TPUInfeedManager` for a description for how infeed 530 is managed. 531 """ 532 533 @abc.abstractmethod 534 def make_input_specs(self, input_tensors): 535 """Constructs the infeed_specs for the given Infeed instance. 536 537 Args: 538 input_tensors: The inputs to the model. 539 540 Returns: 541 A list of 542 """ 543 pass 544 545 def make_feed_dict(self, tpu_model_op): 546 """Constructs a feed_dict for this instance, given the tpu_model_op. 547 548 Args: 549 tpu_model_op: A `TPUModelOp` representing the TPU Model for this 550 instance's input spec. 551 552 Returns: 553 A dictionary to use as the feed_dict of a `session.run` call. 554 """ 555 pass 556 557 558@six.add_metaclass(abc.ABCMeta) 559class TPUInfeedManager(object): 560 """TPUInfeedManager manages the data infeeding of data to a TPU computation. 561 562 Because there are multiple data sources (e.g. in-memory NumPy arrays, 563 `tf.data.Dataset`s), we abstract the different logic behind a single 564 interface: the `TPUInfeedManager`. 565 566 (1) A `TPUFunction` is called with a set of inputs. Based on the inputs, 567 `TPUFunction` retrieves the corresponding `TPUInfeedManager` (or constructs a 568 new one if required). 569 570 (2) The `TPUFunction` calls `make_infeed_instance` on the `TPUInfeedManager` 571 which returns a `TPUInfeedInstance`. 572 573 (3) The `TPUFunction` checks in the shape cache for a pre-compiled instance of 574 the model based on the returned `input_specs` from `TPUInfeedInstance`. 575 576 (4) [Optional.] If the model has not already been instantiated for the given 577 input spec, the `TPUFunction` compiles the model for the input spec (using the 578 `TPUInfeedManager`). 579 580 (5) The `TPUInfeedInstance` constructs the session.run's feed_dict given the 581 compiled model instance corresponding to its shape. 582 """ 583 584 @abc.abstractmethod 585 def make_infeed_instance(self, inputs): 586 """Given a single step's input, construct a `TPUInfeedInstance`. 587 588 Args: 589 inputs: The inputs to a given step. 590 591 Returns: 592 A subclass of `TPUInfeedInstance`. 593 """ 594 pass 595 596 @abc.abstractmethod 597 def build_infeed_from_input_specs(self, input_specs, execution_mode): 598 """For a given input specification (size, type), construct the infeed ops. 599 600 This is called only once for a given input specification and builds the 601 graph ops. It does not have a pointer to the actual infeed data. 602 603 Args: 604 input_specs: TODO(saeta): Document me! 605 execution_mode: TODO(saeta): Document me! 606 607 Returns: 608 A `SizedInfeed` instance. 609 """ 610 pass 611 612 613class TPUNumpyInfeedManager(TPUInfeedManager): 614 """TPU Infeed manager for Numpy inputs.""" 615 616 class NumpyInfeedInstance(TPUInfeedInstance): 617 """Infeed instance for Numpy inputs.""" 618 619 def __init__(self, sharded_inputs): 620 self._sharded_inputs = sharded_inputs 621 622 def make_input_specs(self, input_tensors): 623 # Compute an input specification (used to generate infeed enqueue and 624 # dequeue operations). We use the shape from our input array and the 625 # dtype from our model. A user may pass in a float64 for a float32 626 # input: for model compatibility we still must generate a float32 infeed. 627 input_specs = [] 628 # We use the shape and dtype from the first shard to compute the input 629 # metadata (`input_specs`); all replicas have the same type and shape. 630 for tensor, ary in zip(input_tensors, self._sharded_inputs[0]): 631 input_specs.append( 632 tensor_spec.TensorSpec(ary.shape, tensor.dtype, 633 _valid_name(tensor.name))) 634 635 return input_specs 636 637 def make_feed_dict(self, tpu_model_op): 638 infeed_dict = {} 639 for infeed_tensors, inputs in zip(tpu_model_op.infeed_tensors, 640 self._sharded_inputs): 641 for tensor, value in zip(infeed_tensors, inputs): 642 infeed_dict[tensor] = value 643 return infeed_dict 644 645 def __init__(self, tpu_assignment): 646 self._tpu_assignment = tpu_assignment 647 648 def _split_tensors(self, inputs): 649 """Split input data across shards. 650 651 Each input is sliced along the batch axis. 652 653 Args: 654 inputs: List of Numpy arrays to run on the TPU. 655 656 Returns: 657 List of lists containing the input to feed to each TPU shard. 658 """ 659 if self._tpu_assignment.num_towers == 1: 660 return [inputs] 661 662 batch_size = inputs[0].shape[0] 663 assert batch_size % self._tpu_assignment.num_towers == 0, ( 664 'batch_size must be divisible by the number of TPU cores in use (%s ' 665 'vs %s)' % (batch_size, self._tpu_assignment.num_towers)) 666 shard_size = batch_size // self._tpu_assignment.num_towers 667 input_list = [] 668 for index in range(self._tpu_assignment.num_towers): 669 shard_inputs = [ 670 x[index * shard_size:(index + 1) * shard_size] for x in inputs 671 ] 672 input_list.append(shard_inputs) 673 return input_list 674 675 def make_infeed_instance(self, inputs): 676 sharded_inputs = self._split_tensors(inputs) 677 return self.NumpyInfeedInstance(sharded_inputs) 678 679 def build_infeed_from_input_specs(self, input_specs, execution_mode): 680 infeed_op = [] 681 shard_infeed_tensors = [] 682 683 for shard_id in range(self._tpu_assignment.num_towers): 684 with ops.device( 685 '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name): 686 infeed_tensors = [] 687 with ops.device('/device:TPU:%d' % shard_id): 688 for spec in input_specs: 689 # Construct placeholders for each of the inputs. 690 infeed_tensors.append( 691 array_ops.placeholder( 692 dtype=spec.dtype, 693 shape=spec.shape, 694 name='infeed-enqueue-%s-%d' % (spec.name, shard_id))) 695 shard_infeed_tensors.append(infeed_tensors) 696 697 infeed_op.append( 698 tpu_ops.infeed_enqueue_tuple( 699 infeed_tensors, [spec.shape for spec in input_specs], 700 name='infeed-enqueue-%s-%d' % (execution_mode, shard_id), 701 device_ordinal=shard_id)) 702 return SizedInfeed( 703 infeed_ops=infeed_op, sharded_infeed_tensors=shard_infeed_tensors) 704 705 706class TPUDatasetInfeedManager(TPUInfeedManager): 707 """Manages infeed for a `tf.data.Dataset` into a TPU computation. 708 709 """ 710 711 class DatasetInfeedInstance(TPUInfeedInstance): 712 """An instance of the TPU infeed.""" 713 714 def __init__(self, input_specs): 715 self._input_specs = input_specs 716 717 def make_input_specs(self, input_tensors): 718 # TODO(saeta): Do error checking here! 719 return self._input_specs 720 721 def make_feed_dict(self, tpu_model_op): 722 # TODO(saeta): Verify tpu_model_op is as expected! 723 return {} 724 725 # pylint: disable=redefined-outer-name 726 def __init__(self, dataset, tpu_assignment, mode): 727 """Constructs a TPUDatasetInfeedManager. 728 729 Args: 730 dataset: A `tf.data.Dataset` to infeed. 731 tpu_assignment: The `TPUAssignment` used to configure the 732 Keras TPU model. 733 mode: ModeKeys enum. 734 """ 735 self._verify_dataset_shape(dataset) 736 737 self._dataset = dataset 738 self._tpu_assignment = tpu_assignment 739 dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset) 740 dummy_x_shape = dataset_output_shapes[0].as_list() 741 dummy_x_shape[0] *= tpu_assignment.num_towers 742 dummy_y_shape = dataset_output_shapes[1].as_list() 743 dummy_y_shape[0] *= tpu_assignment.num_towers 744 self._iterator = dataset_ops.make_initializable_iterator(dataset) 745 K.get_session().run(self._iterator.initializer) 746 747 self._get_next_ops = [] 748 ctrl_deps = [] 749 for i in range(tpu_assignment.num_towers): 750 with ops.control_dependencies(ctrl_deps): # Ensure deterministic 751 # TODO(saeta): Ensure correct placement! 752 get_next_op = self._iterator.get_next() 753 self._get_next_ops.append(get_next_op) 754 ctrl_deps.extend(get_next_op) 755 756 # Use dummy numpy inputs for the rest of Keras' shape checking. We 757 # intercept them when building the model. 758 dataset_output_types = dataset_ops.get_legacy_output_types(dataset) 759 self._dummy_x = np.zeros( 760 dummy_x_shape, dtype=dataset_output_types[0].as_numpy_dtype) 761 self._dummy_y = np.zeros( 762 dummy_y_shape, dtype=dataset_output_types[1].as_numpy_dtype) 763 764 input_specs = [] 765 iterator_output_shapes = dataset_ops.get_legacy_output_shapes( 766 self._iterator) 767 iterator_output_types = dataset_ops.get_legacy_output_types(self._iterator) 768 if isinstance(iterator_output_shapes, tuple): 769 assert isinstance(iterator_output_types, tuple) 770 assert len(iterator_output_shapes) == len(iterator_output_types) 771 for i in range(len(iterator_output_shapes)): 772 spec = tensor_spec.TensorSpec(iterator_output_shapes[i], 773 iterator_output_types[i]) 774 input_specs.append(spec) 775 elif isinstance(iterator_output_shapes, tensor_shape.TensorShape): 776 spec = tensor_spec.TensorSpec(iterator_output_shapes, 777 iterator_output_types) 778 input_specs.append(spec) 779 780 # Pre-process the inputs and get_next_ops before caching. 781 input_specs, self._get_next_ops = ( 782 _inject_tpu_inputs_for_dataset( 783 tpu_assignment, mode, input_specs, self._get_next_ops)) 784 self._infeed_instance = self.DatasetInfeedInstance(input_specs) 785 786 def _verify_dataset_shape(self, dataset): 787 """Verifies a dataset is of an appropriate shape for TPUs.""" 788 dataset_output_shapes = dataset_ops.get_legacy_output_shapes(dataset) 789 dataset_output_classes = dataset_ops.get_legacy_output_classes(dataset) 790 if not isinstance(dataset, dataset_ops.DatasetV2): 791 raise ValueError('The function passed as the `x` parameter did not ' 792 'return a `tf.data.Dataset`.') 793 if not isinstance(dataset_output_classes, tuple): 794 raise ValueError('The dataset must return a tuple of tf.Tensors, ' 795 'instead it returns: %s' % dataset_output_classes) 796 if len(dataset_output_classes) != 2: 797 raise ValueError('The dataset must return a 2-element tuple, got ' 798 '%s output classes instead.' % (dataset_output_classes,)) 799 for i, cls in enumerate(dataset_output_classes): 800 if cls != ops.Tensor: 801 raise ValueError('The dataset returned a non-Tensor type (%s) at ' 802 'index %d.' % (cls, i)) 803 for i, shape in enumerate(dataset_output_shapes): 804 if not shape: 805 raise ValueError('The dataset returns a scalar tensor in ' 806 'tuple index %d. Did you forget to batch? ' 807 '(Output shapes: %s).' % (i, dataset_output_shapes)) 808 for j, dim in enumerate(shape): 809 if dim.value is None: 810 if j == 0: 811 hint = (' Hint: did you use `ds.batch(BATCH_SIZE, ' 812 'drop_remainder=True)`?') 813 else: 814 hint = '' 815 raise ValueError( 816 'The Keras-TPU integration for `tf.data` ' 817 'currently requires static shapes. The provided ' 818 'dataset only has a partially defined shape. ' 819 '(Dimension %d of output tensor %d is not statically known ' 820 'for output shapes: %s.%s)' % (j, i, dataset_output_shapes, hint)) 821 822 @property 823 def dummy_x(self): 824 return self._dummy_x 825 826 @property 827 def dummy_y(self): 828 return self._dummy_y 829 830 def make_infeed_instance(self, inputs): 831 # TODO(saeta): Verify inputs is as expected. 832 return self._infeed_instance 833 834 def build_infeed_from_input_specs(self, input_specs, execution_mode): 835 shard_infeed_tensors = self._get_next_ops 836 assert len(shard_infeed_tensors) == self._tpu_assignment.num_towers 837 infeed_ops = [] 838 for shard_id in range(self._tpu_assignment.num_towers): 839 with ops.device( 840 '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name): 841 infeed_ops.append( 842 tpu_ops.infeed_enqueue_tuple( 843 shard_infeed_tensors[shard_id], 844 [spec.shape for spec in input_specs], 845 name='infeed-enqueue-%s-%d' % (execution_mode, shard_id), 846 device_ordinal=shard_id)) 847 return SizedInfeed( 848 infeed_ops=infeed_ops, sharded_infeed_tensors=shard_infeed_tensors) 849 850 851def _inject_tpu_inputs_for_dataset(tpu_assignment, mode, 852 input_specs, get_next_ops): 853 """Append core information to the set of dataset inputs.""" 854 # This is used during compilation to identify the current TPU core and enable 855 # concatenation operations across cores. 856 if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]: 857 return input_specs, get_next_ops 858 859 # Dataset inputs operate on per core basis. 860 per_core_batch_size = input_specs[0].shape.as_list()[0] 861 862 # Insert, at head, the tensor for core_id. 863 assert len(get_next_ops) == tpu_assignment.num_towers 864 for i in range(tpu_assignment.num_towers): 865 core_id_constant = constant_op.constant( 866 np.array([i] * per_core_batch_size).astype('int32'), 867 dtype=dtypes.int32, 868 name='cord_id_constant') 869 get_next_ops[i] = [core_id_constant] + list(get_next_ops[i]) 870 871 # Insert the input spec at head also. 872 input_specs = [tensor_spec.TensorSpec([per_core_batch_size], dtypes.int32) 873 ] + input_specs 874 875 return input_specs, get_next_ops 876 877 878def _inject_tpu_inputs_for_infeed(tpu_assignment, mode, 879 core_id_place_holder, input_tensors, inputs): 880 """Append core information to the set of inputs.""" 881 # This is used during compilation to identify the current TPU core and enable 882 # concatenation operations across cores. 883 if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]: 884 return input_tensors, inputs 885 886 # Puts a place holder in input spec. 887 input_tensors = [core_id_place_holder] + input_tensors 888 889 # Now fill the core id. For `num_cores` = 2, `batch_size` = 8, we fill the 890 # core id inputs as [0, 0, 0, 0, 1, 1, 1, 1], so each core sees its core id 891 # (duplicated). 892 num_cores = tpu_assignment.num_towers 893 per_core_batch_size = inputs[0].shape[0] // num_cores 894 core_ids = np.arange(num_cores).repeat(per_core_batch_size) 895 inputs = [core_ids] + inputs 896 return input_tensors, inputs 897 898 899def _read_tpu_coreid_from_infeed(mode, infeed_tensors): 900 """Popping out the core ids from infeed.""" 901 if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL]: 902 return None, infeed_tensors 903 904 if len(infeed_tensors) <= 1: 905 raise RuntimeError( 906 'The infeed tensors on TPU core has only {} tensors. ' 907 'This is not expected. Please report a bug.\nTensors: {}'.format( 908 len(infeed_tensors), infeed_tensors)) 909 910 core_id = infeed_tensors[0][0] # Pop out the scalar version. 911 rest = infeed_tensors[1:] 912 return core_id, rest 913 914 915class TPUFunction(object): 916 """K.function compatible interface for invoking a TPU compiled function. 917 918 Recompilation is triggered on-demand for each set of new inputs shapes: the 919 results are cached for future execution. We expect most computations will 920 be dominated by a standard batch-size, followed by a straggler batch for 921 the end of training or evaluation. 922 923 All `inputs` and `outputs` will be loaded via the infeed and outfeed queues 924 instead of being injected as `feed_dict` items or fetches. 925 """ 926 927 def __init__(self, model, execution_mode, tpu_assignment): 928 self.model = model 929 self.execution_mode = execution_mode 930 self._tpu_assignment = tpu_assignment 931 self._compilation_cache = {} 932 self._cloned_model = None 933 self._cloned_optimizer = None 934 # Create a placeholder for the TPU core ID. Cache the placeholder to avoid 935 # modifying the graph for every batch. 936 self._core_id_place_holder = array_ops.placeholder( 937 dtype=dtypes.int32, shape=[1], name='core_id') 938 939 def _specialize_model(self, input_specs, infeed_manager): 940 """Specialize `self.model` (a Keras model) for the given input shapes.""" 941 # Re-create our input and output layers inside our subgraph. They will be 942 # attached to the true computation when we clone our model in `tpu_fn`. 943 K.set_learning_phase(self.execution_mode == model_fn_lib.ModeKeys.TRAIN) 944 945 # functools.partial and callable objects are not supported by tpu.rewrite 946 def _model_fn(): 947 """Compute fit/eval/predict for the TPU.""" 948 is_training = self.execution_mode == model_fn_lib.ModeKeys.TRAIN 949 is_test = self.execution_mode == model_fn_lib.ModeKeys.EVAL 950 is_predict = self.execution_mode == model_fn_lib.ModeKeys.PREDICT 951 952 # During train/eval, we infeed our features as well as labels. 953 if is_training or is_test: 954 infeed_layers = self.model._input_layers + self.model._output_layers 955 else: 956 infeed_layers = self.model._input_layers 957 958 # Generate our infeed operation to read features & labels. 959 infeed_tensors = tpu_ops.infeed_dequeue_tuple( 960 dtypes=[spec.dtype for spec in input_specs], 961 shapes=[spec.shape for spec in input_specs], 962 name='infeed-%s' % self.execution_mode) 963 964 core_id, infeed_tensors = ( 965 _read_tpu_coreid_from_infeed( 966 mode=self.execution_mode, infeed_tensors=infeed_tensors)) 967 968 assert len(infeed_tensors) == len(infeed_layers), ( 969 'Infeed inputs did not match model: %s vs %s' % (infeed_layers, 970 infeed_tensors)) 971 972 tpu_targets = [] 973 tpu_input_map = {} 974 975 # Sort infeed outputs into inputs and labels for calling our Keras model. 976 for tensor, layer in zip(infeed_tensors, infeed_layers): 977 if layer in self.model._input_layers: 978 tpu_input_map[layer.name] = tensor 979 if layer in self.model._output_layers: 980 tpu_targets.append(tensor) 981 982 # Clone our CPU model, running within the TPU device context. 983 # 984 # We use the id of the original model as a key to avoid weight collisions 985 # (if a user re-runs the same model multiple times, in e.g. Colab). 986 with TPURewriteContext(tpu_input_map): 987 with variable_scope.variable_scope('tpu_%s' % id(self.model)): 988 with keras_tpu_variables.replicated_scope( 989 self._tpu_assignment.num_towers): 990 if not self._cloned_optimizer: 991 self._cloned_optimizer = _clone_optimizer( 992 self.model.cpu_optimizer, 993 worker_name=self._tpu_assignment.worker_name) 994 995 self._cloned_model = models.clone_model(self.model) 996 997 # When running on more than one core, concatenate outputs at the end 998 # of processing. In backprop stage, the gradients will be 999 # calculated according to the local inputs as gradient of 1000 # cross-replica-concat being zero for any outputs other than those 1001 # from mlocal core so the loss calculation is identical. 1002 num_towers = self.model._tpu_assignment.num_towers 1003 if num_towers > 1 and (is_training or is_test): 1004 new_outputs = [ 1005 _cross_replica_concat( 1006 o, core_id, num_towers, 1007 name='model output ({})'.format(o.name)) 1008 for o in self._cloned_model.outputs 1009 ] 1010 # Recast all low precision outputs back to float32 since we only 1011 # casted the inputs to bfloat16 and not targets. This is done so 1012 # that we can preserve precision when calculating the loss value. 1013 if new_outputs and new_outputs[0].dtype == dtypes.bfloat16: 1014 new_outputs = [ 1015 math_ops.cast(o, dtypes.float32) for o in new_outputs] 1016 self._cloned_model.outputs = new_outputs 1017 tpu_targets = [ 1018 _cross_replica_concat( 1019 tensor, 1020 core_id, 1021 num_towers, 1022 name='model target ({})'.format(tensor.name)) 1023 for tensor in tpu_targets 1024 ] 1025 1026 if is_training or is_test: 1027 with variable_scope.variable_scope( 1028 'metrics', reuse=variable_scope.AUTO_REUSE): 1029 self._cloned_model.compile( 1030 optimizer=_replicated_optimizer(self._cloned_optimizer), 1031 loss=self.model.loss, 1032 loss_weights=self.model.loss_weights, 1033 metrics=metrics_module.clone_metrics( 1034 self.model._compile_metrics), 1035 weighted_metrics=metrics_module.clone_metrics( 1036 self.model._compile_weighted_metrics), 1037 target_tensors=tpu_targets, 1038 ) 1039 1040 # Compute our outfeed depending on the execution mode 1041 if is_training: 1042 if not isinstance(self._cloned_optimizer, keras_optimizers.TFOptimizer): 1043 # For Keras optimizer, we try to place the variable weights on the TPU 1044 # device. Keras creates optimizer variables (e.g. momentum values for 1045 # the Momentum optimizer) when _make_train_function is invoked. 1046 with keras_tpu_variables.replicated_variable_for_optimizer( 1047 self._tpu_assignment.num_towers): 1048 self._cloned_model._make_train_function() 1049 else: 1050 self._cloned_model._make_train_function() 1051 1052 self._outfeed_spec = [ 1053 tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) 1054 for tensor in self._cloned_model.train_function.outputs 1055 ] 1056 return [ 1057 self._cloned_model.train_function.updates_op, 1058 tpu_ops.outfeed_enqueue_tuple( 1059 self._cloned_model.train_function.outputs, 1060 name='outfeed-enqueue-train') 1061 ] 1062 elif is_test: 1063 self._cloned_model._make_test_function() 1064 self._outfeed_spec = [ 1065 tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) 1066 for tensor in self._cloned_model.test_function.outputs 1067 ] 1068 return [ 1069 tpu_ops.outfeed_enqueue_tuple( 1070 self._cloned_model.test_function.outputs, 1071 name='outfeed-enqueue-test') 1072 ] 1073 elif is_predict: 1074 self._cloned_model._make_predict_function() 1075 self._outfeed_spec = [ 1076 tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name) 1077 for tensor in self._cloned_model.predict_function.outputs 1078 ] 1079 return [ 1080 tpu_ops.outfeed_enqueue_tuple( 1081 self._cloned_model.predict_function.outputs, 1082 name='outfeed-enqueue-predict', 1083 ) 1084 ] 1085 else: 1086 assert False, 'Unexpected execution mode: %s' % self.execution_mode 1087 1088 # Capture outfeed metadata computed during the rewrite. 1089 self._outfeed_spec = None 1090 1091 # Generate out TPU operations using `tpu.split_compile_and_replicate`. 1092 # `compile_op` can be used to test the TPU model compiles before execution. 1093 # `execute op` replicates `_model_fn` `num_replicas` times, with each shard 1094 # running on a different logical core. 1095 compile_op, execute_op = tpu.split_compile_and_replicate( 1096 _model_fn, inputs=[[] for _ in range(self._tpu_assignment.num_towers)]) 1097 1098 # Generate CPU side operations to enqueue features/labels and dequeue 1099 # outputs from the model call. 1100 sized_infeed = infeed_manager.build_infeed_from_input_specs( 1101 input_specs, self.execution_mode) 1102 # Build output ops. 1103 outfeed_op = [] 1104 for shard_id in range(self._tpu_assignment.num_towers): 1105 with ops.device( 1106 '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name): 1107 outfeed_op.extend( 1108 tpu_ops.outfeed_dequeue_tuple( 1109 dtypes=[spec.dtype for spec in self._outfeed_spec], 1110 shapes=[spec.shape for spec in self._outfeed_spec], 1111 name='outfeed-dequeue-%s-%d' % (self.execution_mode, shard_id), 1112 device_ordinal=shard_id)) 1113 1114 return TPUModelOp( 1115 compile_op, 1116 execute_op, 1117 infeed_tensors=sized_infeed.sharded_infeed_tensors, 1118 infeed_op=sized_infeed.infeed_ops, 1119 outfeed_op=outfeed_op) 1120 1121 def _test_model_compiles(self, tpu_model_ops): 1122 """Verifies that the given TPUModelOp can be compiled via XLA.""" 1123 logging.info('Started compiling') 1124 start_time = time.time() 1125 1126 result = K.get_session().run(tpu_model_ops.compile_op) 1127 proto = tpu_compilation_result.CompilationResultProto() 1128 proto.ParseFromString(result) 1129 if proto.status_error_message: 1130 raise RuntimeError('Compilation failed: {}'.format( 1131 proto.status_error_message)) 1132 1133 end_time = time.time() 1134 logging.info('Finished compiling. Time elapsed: %s secs', 1135 end_time - start_time) 1136 1137 def _lookup_infeed_manager(self, inputs): 1138 """Return an existing manager, or construct a new InfeedManager for inputs. 1139 1140 _lookup_infeed_manager will return an existing InfeedManager if one has been 1141 previously assigned for this model and input. If not, it will construct a 1142 new TPUNumpyInfeedManager. 1143 1144 Args: 1145 inputs: A NumPy input to the model. 1146 1147 Returns: 1148 A `TPUInfeedManager` object to manage infeeds for this input. 1149 """ 1150 if inputs is None: 1151 return None 1152 1153 for x, mgr in self.model._numpy_to_infeed_manager_list: 1154 if inputs[0] is x: 1155 return mgr 1156 1157 return TPUNumpyInfeedManager(self.model._tpu_assignment) 1158 1159 def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager): 1160 """Looks up the corresponding `TPUModelOp` for a given `input_specs`. 1161 1162 It instantiates a new copy of the model for each unique input shape. 1163 1164 Args: 1165 input_specs: The specification of the inputs to train on. 1166 infeed_manager: The infeed manager responsible for feeding in data. 1167 1168 Returns: 1169 A `TPUModelOp` instance that can be used to execute a step of the model. 1170 """ 1171 if input_specs is None or infeed_manager is None: 1172 # Note: this condition is possible during the prologue or epilogue of the 1173 # pipelined loop. 1174 return None 1175 1176 # XLA requires every operation in the graph has a fixed shape. To 1177 # handle varying batch sizes we recompile a new sub-graph for each 1178 # unique input shape. 1179 shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs]) 1180 if shape_key not in self._compilation_cache: 1181 logging.info( 1182 'New input shapes; (re-)compiling: mode=%s ' 1183 '(# of cores %d), %s', self.execution_mode, 1184 self._tpu_assignment.num_towers, input_specs) 1185 new_tpu_model_ops = self._specialize_model(input_specs, 1186 infeed_manager) 1187 self._compilation_cache[shape_key] = new_tpu_model_ops 1188 self._test_model_compiles(new_tpu_model_ops) 1189 1190 return self._compilation_cache[shape_key] 1191 1192 def _construct_input_tensors_and_inputs(self, inputs): 1193 """Returns input tensors and numpy array inputs corresponding to `inputs`. 1194 1195 Args: 1196 inputs: NumPy inputs. 1197 1198 Returns: 1199 A tuple of `input_tensors`, and `inputs`. 1200 """ 1201 if inputs is None: 1202 # Note: this condition is possible during the prologue or epilogue of the 1203 # pipelined loop. 1204 return None, None 1205 1206 if isinstance(inputs[-1], int): 1207 # Remove the learning_phase flag at the end. We currently hard code the 1208 # learning_phase in TPUFunction. 1209 inputs = inputs[:-1] 1210 1211 if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or 1212 self.execution_mode == model_fn_lib.ModeKeys.EVAL): 1213 # Strip sample weight from inputs. 1214 input_tensors = self.model._feed_inputs + self.model._feed_targets 1215 else: 1216 input_tensors = self.model._feed_inputs 1217 1218 inputs = inputs[:len(input_tensors)] 1219 input_tensors, inputs = ( 1220 _inject_tpu_inputs_for_infeed( 1221 self._tpu_assignment, self.execution_mode, 1222 self._core_id_place_holder, input_tensors, inputs)) 1223 return input_tensors, inputs 1224 1225 def _process_outputs(self, outfeed_outputs): 1226 """Processes the outputs of a model function execution. 1227 1228 Args: 1229 outfeed_outputs: The sharded outputs of the TPU computation. 1230 1231 Returns: 1232 The aggregated outputs of the TPU computation to be used in the rest of 1233 the model execution. 1234 """ 1235 # TODO(xiejw): Decide how to reduce outputs, or discard all but first. 1236 if self.execution_mode == model_fn_lib.ModeKeys.PREDICT: 1237 outputs = [[] for _ in range(len(self._outfeed_spec))] 1238 outputs_per_replica = len(self._outfeed_spec) 1239 1240 for i in range(self._tpu_assignment.num_towers): 1241 output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) * 1242 outputs_per_replica] 1243 for j in range(outputs_per_replica): 1244 outputs[j].append(output_group[j]) 1245 1246 return [np.concatenate(group) for group in outputs] 1247 else: 1248 return outfeed_outputs[:len(outfeed_outputs) // 1249 self._tpu_assignment.num_towers] 1250 1251 def __call__(self, inputs): 1252 """__call__ executes the function on the computational hardware. 1253 1254 It handles executing infeed, and preprocessing in addition to executing the 1255 model on the TPU hardware. 1256 1257 Note: `__call__` has a sibling method `pipeline_run` which performs the same 1258 operations, but with software pipelining. 1259 1260 Args: 1261 inputs: The inputs to use to train. 1262 1263 Returns: 1264 The output of the computation for the given mode it is executed in. 1265 1266 Raises: 1267 RuntimeError: If there is an inappropriate use of the function. 1268 """ 1269 assert isinstance(inputs, list) 1270 1271 infeed_manager = self._lookup_infeed_manager(inputs) 1272 input_tensors, inputs = self._construct_input_tensors_and_inputs(inputs) 1273 infeed_instance = infeed_manager.make_infeed_instance(inputs) 1274 del inputs # To avoid accident usage. 1275 input_specs = infeed_instance.make_input_specs(input_tensors) 1276 tpu_model_ops = self._tpu_model_ops_for_input_specs(input_specs, 1277 infeed_manager) 1278 infeed_dict = infeed_instance.make_feed_dict(tpu_model_ops) 1279 1280 # Initialize our TPU weights on the first compile. 1281 self.model._initialize_weights(self._cloned_model) 1282 1283 _, _, outfeed_outputs = K.get_session().run([ 1284 tpu_model_ops.infeed_op, tpu_model_ops.execute_op, 1285 tpu_model_ops.outfeed_op 1286 ], infeed_dict) 1287 return self._process_outputs(outfeed_outputs) 1288 1289 def pipeline_run(self, cur_step_inputs, next_step_inputs): 1290 """pipeline_run executes the function on the computational hardware. 1291 1292 pipeline_run performs the same computation as __call__, however it runs the 1293 infeed in a software pipelined fashion compared to the on-device execution. 1294 1295 Note: it is the responsibility of the caller to call `pipeline_run` in the 1296 following sequence: 1297 - Once with `cur_step_inputs=None` and `next_step_inputs=list(...)` 1298 - `n` times with `cur_step_inputs` and `next_step_inputs` as `list`s 1299 - Once with `cur_step_inputs=list(...)` and `next_step_inputs=None` 1300 Additionally, it is the responsibility of the caller to pass 1301 `next_step_inputs` as `cur_step_inputs` on the next invocation of 1302 `pipeline_run`. 1303 1304 Args: 1305 cur_step_inputs: The current step's inputs. 1306 next_step_inputs: The next step's inputs. 1307 1308 Returns: 1309 The output of the computation for the given mode it is executed in. 1310 1311 Raises: 1312 RuntimeError: If there is an inappropriate use of the function. 1313 """ 1314 # Software pipelined case. 1315 next_step_infeed_manager = self._lookup_infeed_manager(next_step_inputs) 1316 cur_step_infeed_manager = self._lookup_infeed_manager(cur_step_inputs) 1317 1318 if (next_step_infeed_manager is not None and 1319 cur_step_infeed_manager is not None): 1320 assert type(next_step_infeed_manager) is type(cur_step_infeed_manager) 1321 1322 next_input_tensors, next_step_inputs = ( 1323 self._construct_input_tensors_and_inputs(next_step_inputs)) 1324 cur_input_tensors, cur_step_inputs = ( 1325 self._construct_input_tensors_and_inputs(cur_step_inputs)) 1326 1327 cur_infeed_instance = None 1328 if cur_step_infeed_manager: 1329 cur_infeed_instance = cur_step_infeed_manager.make_infeed_instance( 1330 cur_step_inputs) 1331 next_infeed_instance = None 1332 if next_step_infeed_manager: 1333 next_infeed_instance = next_step_infeed_manager.make_infeed_instance( 1334 next_step_inputs) 1335 1336 del cur_step_inputs # Avoid accidental re-use. 1337 del next_step_inputs # Avoid accidental re-use. 1338 1339 cur_tpu_model_ops = None 1340 next_tpu_model_ops = None 1341 infeed_dict = None 1342 1343 if cur_infeed_instance and cur_input_tensors and cur_step_infeed_manager: 1344 cur_input_specs = cur_infeed_instance.make_input_specs(cur_input_tensors) 1345 cur_tpu_model_ops = self._tpu_model_ops_for_input_specs( 1346 cur_input_specs, cur_step_infeed_manager) 1347 1348 if (next_infeed_instance and next_input_tensors and 1349 next_step_infeed_manager): 1350 next_input_specs = next_infeed_instance.make_input_specs( 1351 next_input_tensors) 1352 next_tpu_model_ops = self._tpu_model_ops_for_input_specs( 1353 next_input_specs, next_step_infeed_manager) 1354 infeed_dict = next_infeed_instance.make_feed_dict(next_tpu_model_ops) 1355 1356 # Initialize our TPU weights on the first compile. 1357 self.model._initialize_weights(self._cloned_model) 1358 1359 if next_tpu_model_ops and cur_tpu_model_ops: 1360 _, _, outfeed_outputs = K.get_session().run([ 1361 next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op, 1362 cur_tpu_model_ops.outfeed_op 1363 ], infeed_dict) 1364 return self._process_outputs(outfeed_outputs) 1365 1366 if cur_tpu_model_ops: 1367 _, outfeed_outputs = K.get_session().run( 1368 [cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op]) 1369 return self._process_outputs(outfeed_outputs) 1370 1371 if next_tpu_model_ops: 1372 K.get_session().run(next_tpu_model_ops.infeed_op, infeed_dict) 1373 return None 1374 raise RuntimeError('Internal error: both current & next tpu_model_ops ' 1375 'were None') 1376 1377 1378class KerasTPUModel(models.Model): 1379 """TPU compatible Keras model wrapper.""" 1380 1381 def __init__(self, cpu_model, strategy): 1382 super(models.Model, self).__init__( # pylint: disable=bad-super-call 1383 inputs=cpu_model.inputs, 1384 outputs=cpu_model.outputs, 1385 name=cpu_model.name, 1386 ) 1387 if tf2.enabled(): 1388 raise RuntimeError( 1389 'Keras support is now deprecated in support of TPU Strategy. ' 1390 'Please follow the distribution strategy guide on tensorflow.org ' 1391 'to migrate to the 2.0 supported version.') 1392 else: 1393 logging.warning( 1394 'Keras support is now deprecated in support of TPU Strategy. ' 1395 'Please follow the distribution strategy guide on tensorflow.org ' 1396 'to migrate to the 2.0 supported version.') 1397 # Create a mapping from numpy arrays to infeed managers. 1398 # Note: uses a list of tuples instead of a map because numpy arrays are 1399 # not hashable. 1400 self._numpy_to_infeed_manager_list = [] 1401 1402 # Add distribution specific arguments since we don't call the Model init. 1403 self._distribution_strategy = None 1404 self._compile_distribution = None 1405 1406 self.predict_function = None 1407 self.test_function = None 1408 self.train_function = None 1409 self._stateful_metric_functions = [] 1410 1411 cluster_resolver = strategy._tpu_cluster_resolver 1412 self._tpu_name_or_address = cluster_resolver.get_master() 1413 self._cpu_model = cpu_model 1414 self._tpu_assignment = strategy._make_assignment_for_model(cpu_model) 1415 self._tpu_model = None 1416 self._tpu_weights_initialized = False 1417 1418 # If the input CPU model has already been compiled, compile our TPU model 1419 # immediately. 1420 if self._cpu_model.optimizer: 1421 self.compile( 1422 self._cpu_model.optimizer, 1423 self._cpu_model.loss, 1424 self._cpu_model._compile_metrics, 1425 self._cpu_model.loss_weights, 1426 self._cpu_model.sample_weight_mode, 1427 self._cpu_model._compile_weighted_metrics, 1428 self._cpu_model.target_tensors, 1429 ) 1430 1431 # This flag must be disabled upon model mutation, such as changing the model 1432 # layers or recompiling the model to use a different optimizer. New function 1433 # definitions are generated whenever this flag is disabled, ensuring that 1434 # internal graph functions are always using the current model structure. 1435 # 1436 # Requires declaration here because this constructor skips the 1437 # Model constructor. 1438 self._built_graph_functions = False 1439 1440 def get_config(self): 1441 return { 1442 'cpu_model': self._cpu_model, 1443 'tpu_name_or_address': self._tpu_name_or_address, 1444 'tpu_assignment': self._tpu_assignment, 1445 } 1446 1447 def compile(self, 1448 optimizer, 1449 loss=None, 1450 metrics=None, 1451 loss_weights=None, 1452 sample_weight_mode=None, 1453 weighted_metrics=None, 1454 target_tensors=None, 1455 **kwargs): 1456 if sample_weight_mode: 1457 raise ValueError('sample_weight_mode not supported for TPU execution.') 1458 if weighted_metrics: 1459 raise ValueError('weighted_metrics not supported for TPU execution.') 1460 if target_tensors: 1461 raise ValueError('target_tensors is not supported for TPU execution.') 1462 1463 self._cpu_model.compile( 1464 _clone_optimizer(optimizer), loss, 1465 metrics_module.clone_metrics(metrics), loss_weights, sample_weight_mode, 1466 metrics_module.clone_metrics(weighted_metrics), target_tensors, 1467 **kwargs) 1468 1469 super(KerasTPUModel, self).compile(optimizer, loss, metrics, loss_weights, 1470 sample_weight_mode, weighted_metrics, 1471 target_tensors, **kwargs) 1472 1473 def fit(self, 1474 x=None, 1475 y=None, 1476 batch_size=None, 1477 epochs=1, 1478 verbose=1, 1479 callbacks=None, 1480 validation_split=0., 1481 validation_data=None, 1482 shuffle=True, 1483 class_weight=None, 1484 sample_weight=None, 1485 initial_epoch=0, 1486 steps_per_epoch=None, 1487 validation_steps=None, 1488 **kwargs): 1489 if context.executing_eagerly(): 1490 raise EnvironmentError('KerasTPUModel currently does not support eager ' 1491 'mode.') 1492 1493 with _tpu_session_context(): 1494 assert not self._numpy_to_infeed_manager_list # Ensure empty. 1495 1496 infeed_managers = [] # Managers to clean up at the end of the fit call. 1497 if isinstance(x, dataset_ops.DatasetV2): 1498 # TODO(b/111413240): Support taking a tf.data.Dataset directly. 1499 raise ValueError( 1500 'Taking a Dataset directly is not yet supported. Please ' 1501 'wrap your dataset construction code in a function and ' 1502 'pass that to fit instead. For examples, see: ' 1503 'https://github.com/tensorflow/tpu/tree/master/models/experimental' 1504 '/keras') 1505 if callable(x): 1506 with ops.device( 1507 '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name): 1508 dataset = x() 1509 if steps_per_epoch is None: 1510 raise ValueError('When using tf.data as input to a model, you ' 1511 'should specify the steps_per_epoch argument.') 1512 if y is not None: 1513 raise ValueError('When using tf.data as input to a model, y must ' 1514 'be None') 1515 infeed_manager = TPUDatasetInfeedManager( 1516 dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN) 1517 # Use dummy numpy inputs for the rest of Keras' shape checking. We 1518 # intercept them when building the model. 1519 x = infeed_manager.dummy_x 1520 y = infeed_manager.dummy_y 1521 infeed_managers.append((x, infeed_manager)) 1522 1523 if isinstance(validation_data, dataset_ops.DatasetV2): 1524 # TODO(b/111413240): Support taking a tf.data.Dataset directly. 1525 raise ValueError( 1526 'Taking a Dataset directly is not yet supported. Please ' 1527 'wrap your dataset construction code in a function and ' 1528 'pass that to fit instead. For examples, see: ' 1529 'https://github.com/tensorflow/tpu/tree/master/models/experimental' 1530 '/keras') 1531 if callable(validation_data): 1532 dataset = validation_data() 1533 if validation_steps is None: 1534 raise ValueError('When using tf.data as validation for a model, you ' 1535 'should specify the validation_steps argument.') 1536 infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment, 1537 model_fn_lib.ModeKeys.EVAL) 1538 # Use dummy numpy inputs for the rest of Keras' shape checking. We 1539 # intercept them when building the model. 1540 val_x = infeed_manager.dummy_x 1541 val_y = infeed_manager.dummy_y 1542 infeed_managers.append((val_x, infeed_manager)) 1543 validation_data = (val_x, val_y) 1544 1545 self._numpy_to_infeed_manager_list = infeed_managers 1546 try: 1547 pipeline = kwargs.get('_pipeline', True) 1548 if '_pipeline' in kwargs: 1549 kwargs.pop('_pipeline') 1550 if not pipeline: 1551 logging.info('Running non-pipelined training loop (`_pipeline=%s`).', 1552 pipeline) 1553 return super(KerasTPUModel, self).fit( 1554 x, y, batch_size, epochs, verbose, callbacks, validation_split, 1555 validation_data, shuffle, class_weight, sample_weight, 1556 initial_epoch, steps_per_epoch, validation_steps, **kwargs) 1557 return self._pipeline_fit(x, y, batch_size, epochs, verbose, callbacks, 1558 validation_split, validation_data, shuffle, 1559 class_weight, sample_weight, initial_epoch, 1560 steps_per_epoch, validation_steps, **kwargs) 1561 finally: 1562 self._numpy_to_infeed_manager_list = [] 1563 1564 def evaluate(self, 1565 x=None, 1566 y=None, 1567 batch_size=None, 1568 verbose=1, 1569 sample_weight=None, 1570 steps=None): 1571 original_numpy_to_infeed_manager_list = [] 1572 if self._numpy_to_infeed_manager_list: 1573 # evaluate call may be executed as callbacks during the training. In this 1574 # case, _numpy_to_infeed_manager_list is not empty, so save it for 1575 # recovery at the end of evaluate call. 1576 original_numpy_to_infeed_manager_list = self._numpy_to_infeed_manager_list 1577 self._numpy_to_infeed_manager_list = [] 1578 1579 with _tpu_session_context(): 1580 # Managers to clean up at the end of the evaluate call. 1581 infeed_managers = [] 1582 if isinstance(x, dataset_ops.DatasetV2): 1583 # TODO(b/111413240): Support taking a tf.data.Dataset directly. 1584 raise ValueError( 1585 'Taking a Dataset directly is not yet supported. Please ' 1586 'wrap your dataset construction code in a function and ' 1587 'pass that to fit instead. For examples, see: ' 1588 'https://github.com/tensorflow/tpu/tree/master/models/experimental' 1589 '/keras') 1590 if callable(x): 1591 dataset = x() 1592 if steps is None: 1593 raise ValueError('When using tf.data as input to a model, you ' 1594 'should specify the steps argument.') 1595 if y is not None: 1596 raise ValueError('When using tf.data as input to a model, y must be ' 1597 'None') 1598 infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment, 1599 model_fn_lib.ModeKeys.EVAL) 1600 # Use dummy numpy inputs for the rest of Keras' shape checking. We 1601 # intercept them when building the model. 1602 x = infeed_manager.dummy_x 1603 y = infeed_manager.dummy_y 1604 infeed_managers.append((x, infeed_manager)) 1605 1606 self._numpy_to_infeed_manager_list = infeed_managers 1607 try: 1608 return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose, 1609 sample_weight, steps) 1610 finally: 1611 self._numpy_to_infeed_manager_list = ( 1612 original_numpy_to_infeed_manager_list) 1613 1614 def _pipeline_fit(self, x, y, batch_size, epochs, verbose, callbacks, 1615 validation_split, validation_data, shuffle, class_weight, 1616 sample_weight, initial_epoch, steps_per_epoch, 1617 validation_steps, **kwargs): 1618 # Similar to super.fit(...), but modified to support software pipelining. 1619 1620 # Backwards compatibility 1621 if batch_size is None and steps_per_epoch is None: 1622 batch_size = 32 1623 # Legacy support 1624 if 'nb_epoch' in kwargs: 1625 logging.warning('The `nb_epoch` argument in `fit` has been renamed ' 1626 '`epochs`.') 1627 epochs = kwargs.pop('nb_epoch') 1628 if kwargs: 1629 raise TypeError('Unrecognized keyword arguments: ' + str(kwargs)) 1630 1631 # Validate and standardize user data 1632 x, y, sample_weights = self._standardize_user_data( 1633 x, 1634 y, 1635 sample_weight=sample_weight, 1636 class_weight=class_weight, 1637 batch_size=batch_size, 1638 check_steps=True, 1639 steps_name='steps_per_epoch', 1640 steps=steps_per_epoch, 1641 validation_split=validation_split) 1642 1643 # Prepare validation data 1644 val_x, val_y, val_sample_weights = self._prepare_validation_data( 1645 validation_data, validation_split, validation_steps, x, y, 1646 sample_weights, batch_size) 1647 return self._pipeline_fit_loop( 1648 x, 1649 y, 1650 sample_weights=sample_weights, 1651 batch_size=batch_size, 1652 epochs=epochs, 1653 verbose=verbose, 1654 callbacks=callbacks, 1655 val_inputs=val_x, 1656 val_targets=val_y, 1657 val_sample_weights=val_sample_weights, 1658 shuffle=shuffle, 1659 initial_epoch=initial_epoch, 1660 steps_per_epoch=steps_per_epoch, 1661 validation_steps=validation_steps) 1662 1663 def _pipeline_fit_loop(self, 1664 inputs, 1665 targets, 1666 sample_weights, 1667 batch_size, 1668 epochs, 1669 verbose, 1670 callbacks, 1671 val_inputs, 1672 val_targets, 1673 val_sample_weights, 1674 shuffle, 1675 initial_epoch, 1676 steps_per_epoch, 1677 validation_steps): 1678 self._make_train_function() 1679 sample_weights = sample_weights or [] 1680 val_sample_weights = val_sample_weights or [] 1681 if not isinstance(K.learning_phase(), int): 1682 ins = inputs + targets + sample_weights + [1] 1683 else: 1684 ins = inputs + targets + sample_weights 1685 1686 do_validation = False 1687 if val_inputs: 1688 do_validation = True 1689 if (steps_per_epoch is None and verbose and inputs and 1690 hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')): 1691 print('Train on %d samples, validate on %d samples' % 1692 (inputs[0].shape[0], val_inputs[0].shape[0])) 1693 1694 if validation_steps: 1695 do_validation = True 1696 if steps_per_epoch is None: 1697 raise ValueError('Can only use `validation_steps` when doing step-wise ' 1698 'training, i.e. `steps_per_epoch` must be set.') 1699 1700 num_training_samples = training_utils.check_num_samples( 1701 ins, batch_size, steps_per_epoch, 'steps_per_epoch') 1702 count_mode = 'steps' if steps_per_epoch else 'samples' 1703 callbacks = cbks.configure_callbacks( 1704 callbacks, 1705 self, 1706 do_validation=do_validation, 1707 batch_size=batch_size, 1708 epochs=epochs, 1709 steps_per_epoch=steps_per_epoch, 1710 samples=num_training_samples, 1711 verbose=verbose, 1712 count_mode=count_mode) 1713 1714 if num_training_samples is not None: 1715 index_array = np.arange(num_training_samples) 1716 1717 # To prevent a slowdown, we find beforehand the arrays that need conversion. 1718 feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights 1719 indices_for_conversion_to_dense = [] 1720 for i in range(len(feed)): 1721 if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]): 1722 indices_for_conversion_to_dense.append(i) 1723 1724 callbacks.on_train_begin() 1725 for epoch in range(initial_epoch, epochs): 1726 # Reset stateful metrics 1727 for m in self.metrics: 1728 m.reset_states() 1729 # Update callbacks 1730 callbacks.on_epoch_begin(epoch) 1731 epoch_logs = {} 1732 if steps_per_epoch is not None: 1733 # Step-wise fit loop. 1734 self._pipeline_fit_loop_step_wise( 1735 ins=ins, 1736 callbacks=callbacks, 1737 steps_per_epoch=steps_per_epoch, 1738 epochs=epochs, 1739 do_validation=do_validation, 1740 val_inputs=val_inputs, 1741 val_targets=val_targets, 1742 val_sample_weights=val_sample_weights, 1743 validation_steps=validation_steps, 1744 epoch_logs=epoch_logs) 1745 else: 1746 # Sample-wise fit loop. 1747 self._pipeline_fit_loop_sample_wise( 1748 ins=ins, 1749 callbacks=callbacks, 1750 index_array=index_array, 1751 shuffle=shuffle, 1752 batch_size=batch_size, 1753 num_training_samples=num_training_samples, 1754 indices_for_conversion_to_dense=indices_for_conversion_to_dense, 1755 do_validation=do_validation, 1756 val_inputs=val_inputs, 1757 val_targets=val_targets, 1758 val_sample_weights=val_sample_weights, 1759 validation_steps=validation_steps, 1760 epoch_logs=epoch_logs) 1761 1762 callbacks.on_epoch_end(epoch, epoch_logs) 1763 if callbacks.model.stop_training: 1764 break 1765 callbacks.on_train_end() 1766 return self.history 1767 1768 def _pipeline_fit_loop_sample_wise(self, 1769 ins, 1770 callbacks, 1771 index_array, 1772 shuffle, 1773 batch_size, 1774 num_training_samples, 1775 indices_for_conversion_to_dense, 1776 do_validation, 1777 val_inputs, 1778 val_targets, 1779 val_sample_weights, 1780 validation_steps, 1781 epoch_logs): 1782 f = self.train_function 1783 if shuffle == 'batch': 1784 index_array = training_utils.batch_shuffle(index_array, batch_size) 1785 elif shuffle: 1786 np.random.shuffle(index_array) 1787 batches = make_batches(num_training_samples, batch_size) 1788 1789 ins_last_batch = None 1790 last_batch_logs = None 1791 batch_index = 0 1792 1793 for batch_index, (batch_start, batch_end) in enumerate(batches): 1794 batch_ids = index_array[batch_start:batch_end] 1795 try: 1796 if isinstance(ins[-1], int): 1797 # Do not slice the training phase flag. 1798 ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] 1799 else: 1800 ins_batch = slice_arrays(ins, batch_ids) 1801 except TypeError: 1802 raise TypeError('TypeError while preparing batch. If using HDF5 ' 1803 'input data, pass shuffle="batch".') 1804 1805 # Pipeline batch logs 1806 next_batch_logs = {} 1807 next_batch_logs['batch'] = batch_index 1808 next_batch_logs['size'] = len(batch_ids) 1809 if batch_index > 0: 1810 # Callbacks operate one step behind in software pipeline. 1811 callbacks.on_batch_begin(batch_index - 1, last_batch_logs) 1812 for i in indices_for_conversion_to_dense: 1813 ins_batch[i] = ins_batch[i].toarray() 1814 1815 outs = f.pipeline_run( 1816 cur_step_inputs=ins_last_batch, next_step_inputs=ins_batch) 1817 ins_last_batch = ins_batch 1818 1819 if batch_index == 0: 1820 assert outs is None 1821 else: 1822 if not isinstance(outs, list): 1823 outs = [outs] 1824 for l, o in zip(self.metrics_names, outs): 1825 last_batch_logs[l] = o # pylint: disable=unsupported-assignment-operation 1826 callbacks.on_batch_end(batch_index - 1, last_batch_logs) 1827 if callbacks.model.stop_training: 1828 return 1829 last_batch_logs = next_batch_logs 1830 1831 # Final batch 1832 callbacks.on_batch_begin(batch_index, last_batch_logs) 1833 outs = f.pipeline_run(cur_step_inputs=ins_last_batch, next_step_inputs=None) 1834 if not isinstance(outs, list): 1835 outs = [outs] 1836 for l, o in zip(self.metrics_names, outs): 1837 last_batch_logs[l] = o 1838 callbacks.on_batch_end(batch_index, last_batch_logs) 1839 if callbacks.model.stop_training: 1840 return 1841 1842 if do_validation: 1843 val_outs = training_arrays.test_loop( 1844 self, 1845 val_inputs, 1846 val_targets, 1847 sample_weights=val_sample_weights, 1848 batch_size=batch_size, 1849 steps=validation_steps, 1850 verbose=0) 1851 if not isinstance(val_outs, list): 1852 val_outs = [val_outs] 1853 # Same labels assumed. 1854 for l, o in zip(self.metrics_names, val_outs): 1855 epoch_logs['val_' + l] = o 1856 1857 def _pipeline_fit_loop_step_wise(self, 1858 ins, 1859 callbacks, 1860 steps_per_epoch, 1861 epochs, 1862 do_validation, 1863 val_inputs, 1864 val_targets, 1865 val_sample_weights, 1866 validation_steps, 1867 epoch_logs): 1868 f = self.train_function 1869 1870 # Loop prologue 1871 try: 1872 outs = f.pipeline_run(cur_step_inputs=None, next_step_inputs=ins) 1873 assert outs is None # Function shouldn't return anything! 1874 except errors.OutOfRangeError: 1875 logging.warning('Your dataset iterator ran out of data on the first step ' 1876 'of the epoch, preventing further training. Check to ' 1877 'make sure your paths are correct and you have ' 1878 'permissions to read the files. Skipping validation') 1879 1880 for step_index in range(steps_per_epoch): 1881 batch_logs = {'batch': step_index, 'size': 1} 1882 callbacks.on_batch_begin(step_index, batch_logs) 1883 try: 1884 if step_index < steps_per_epoch - 1: 1885 next_step_inputs = ins 1886 else: 1887 next_step_inputs = None 1888 outs = f.pipeline_run( 1889 cur_step_inputs=ins, next_step_inputs=next_step_inputs) 1890 except errors.OutOfRangeError: 1891 logging.warning('Your dataset iterator ran out of data; ' 1892 'interrupting training. Make sure that your ' 1893 'dataset can generate at least `steps_per_batch * ' 1894 'epochs` batches (in this case, %d batches). You ' 1895 'may need to use the repeat() function when ' 1896 'building your dataset.' % steps_per_epoch * epochs) 1897 break 1898 1899 if not isinstance(outs, list): 1900 outs = [outs] 1901 for l, o in zip(self.metrics_names, outs): 1902 batch_logs[l] = o 1903 1904 callbacks.on_batch_end(step_index, batch_logs) 1905 if callbacks.model.stop_training: 1906 break 1907 1908 if do_validation: 1909 val_outs = training_arrays.test_loop( 1910 self, 1911 val_inputs, 1912 val_targets, 1913 sample_weights=val_sample_weights, 1914 steps=validation_steps, 1915 verbose=0) 1916 if not isinstance(val_outs, list): 1917 val_outs = [val_outs] 1918 # Same labels assumed. 1919 for l, o in zip(self.metrics_names, val_outs): 1920 epoch_logs['val_' + l] = o 1921 1922 def _prepare_validation_data(self, validation_data, validation_split, 1923 validation_steps, x, y, sample_weights, 1924 batch_size): 1925 """Prepares the validation dataset. 1926 1927 Args: 1928 validation_data: The validation data (if provided) 1929 validation_split: The validation split (if provided) 1930 validation_steps: The validation steps (if provided) 1931 x: The main training data x (if provided) 1932 y: The main training data y (if provided) 1933 sample_weights: The sample weights (if provided) 1934 batch_size: The training batch size (if provided) 1935 1936 Returns: 1937 A 3-tuple of (val_x, val_y, val_sample_weights). 1938 1939 Raises: 1940 ValueError: If the provided arguments are not compatible with 1941 `KerasTPUModel`. 1942 """ 1943 # Note: this is similar to a section of $tf/python/keras/engine/training.py 1944 # It differns in that tf.data objects are not allowed to be passed directly. 1945 # Additionally, it handles validating shapes & types appropriately for use 1946 # in TPUs. 1947 if validation_data: 1948 if (isinstance(validation_data, iterator_ops.Iterator) or 1949 isinstance(validation_data, iterator_ops.EagerIterator) or 1950 isinstance(validation_data, dataset_ops.DatasetV2)): 1951 raise ValueError('KerasTPUModel cannot handle a Dataset or Iterator ' 1952 'for validation_data. Please instead pass a function ' 1953 'that returns a `tf.data.Dataset`.') 1954 if len(validation_data) == 2: 1955 val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence 1956 val_sample_weight = None 1957 elif len(validation_data) == 3: 1958 val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence 1959 else: 1960 raise ValueError('When passing a `validation_data` argument, it must ' 1961 'contain either 2 items (x_val, y_val), or 3 items ' 1962 '(x_val, y_val, val_sample_weights). However we ' 1963 'received `validation_data=%s`' % validation_data) 1964 val_x, val_y, val_sample_weights = self._standardize_user_data( 1965 val_x, 1966 val_y, 1967 sample_weight=val_sample_weight, 1968 batch_size=batch_size, 1969 steps=validation_steps) 1970 elif validation_split and 0. < validation_split < 1.: 1971 if training_utils.has_symbolic_tensors(x): 1972 raise ValueError('If your data is in the form of symbolic tensors, you ' 1973 'cannot use `validation_split`.') 1974 if hasattr(x[0], 'shape'): 1975 split_at = int(x[0].shape[0] * (1. - validation_split)) 1976 else: 1977 split_at = int(len(x[0]) * (1. - validation_split)) 1978 1979 x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at)) 1980 y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at)) 1981 sample_weights, val_sample_weights = ( 1982 slice_arrays(sample_weights, 0, split_at), 1983 slice_arrays(sample_weights, split_at) 1984 ) 1985 elif validation_steps: 1986 val_x = [] 1987 val_y = [] 1988 val_sample_weights = [] 1989 else: 1990 val_x = None 1991 val_y = None 1992 val_sample_weights = None 1993 1994 return val_x, val_y, val_sample_weights 1995 1996 def predict(self, 1997 x, 1998 batch_size=None, 1999 verbose=0, 2000 steps=None, 2001 max_queue_size=10, 2002 workers=1, 2003 use_multiprocessing=False): 2004 with _tpu_session_context(): 2005 return super(KerasTPUModel, self).predict( 2006 x, 2007 batch_size=batch_size, 2008 verbose=verbose, 2009 steps=steps, 2010 max_queue_size=max_queue_size, 2011 workers=workers, 2012 use_multiprocessing=use_multiprocessing) 2013 2014 @property 2015 def optimizer(self): 2016 if self._tpu_model: 2017 return self._tpu_model.optimizer 2018 return self._cpu_model.optimizer 2019 2020 @optimizer.setter 2021 def optimizer(self, optimizer): 2022 self._optimizer = optimizer 2023 2024 @property 2025 def metrics(self): 2026 if self._tpu_model: 2027 return self._tpu_model.metrics 2028 return self._stateful_metric_functions 2029 2030 @metrics.setter 2031 def metrics(self, metrics): 2032 self._stateful_metric_functions = metrics 2033 2034 def _make_train_function(self): 2035 if not self.train_function: 2036 self.train_function = TPUFunction( 2037 self, 2038 model_fn_lib.ModeKeys.TRAIN, 2039 tpu_assignment=self._tpu_assignment) 2040 2041 return self.train_function 2042 2043 def _make_test_function(self): 2044 if not self.test_function: 2045 self.test_function = TPUFunction( 2046 self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment) 2047 return self.test_function 2048 2049 def _make_predict_function(self): 2050 if not self.predict_function: 2051 self.predict_function = TPUFunction( 2052 self, 2053 model_fn_lib.ModeKeys.PREDICT, 2054 tpu_assignment=self._tpu_assignment) 2055 return self.predict_function 2056 2057 def _initialize_weights(self, cloned_model): 2058 """Initialize TPU weights. 2059 2060 This is called on the first compile of the TPU model (first call to 2061 fit/predict/evaluate). 2062 2063 Args: 2064 cloned_model: `keras.Model`, TPU model to initialize. 2065 """ 2066 if self._tpu_weights_initialized: 2067 return 2068 2069 self._tpu_model = cloned_model 2070 self._tpu_weights_initialized = True 2071 2072 weights = self._cpu_model.get_weights() 2073 2074 if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer): 2075 cpu_optimizer_config = {} 2076 else: 2077 cpu_optimizer_config = self.cpu_optimizer.get_config() 2078 2079 logging.info('Setting weights on TPU model.') 2080 cloned_model.set_weights(weights) 2081 if self._tpu_model.optimizer is None: 2082 # tpu_model may not be compiled, e.g., loading weights and then predict. 2083 return 2084 for k, v in six.iteritems(cpu_optimizer_config): 2085 if k == 'name': 2086 continue 2087 opt_var = getattr(self._tpu_model.optimizer, k) 2088 if isinstance(opt_var, variables.Variable): 2089 logging.info('CPU -> TPU %s: %s {%s}', k, v, K.get_value(opt_var)) 2090 K.get_session().run(opt_var.assign(v)) 2091 else: 2092 logging.warning('Cannot update non-variable config: %s', k) 2093 2094 @property 2095 def cpu_optimizer(self): 2096 return self._cpu_model.optimizer 2097 2098 def sync_to_cpu(self): 2099 """Copy weights from the CPU, returning a synchronized CPU model.""" 2100 if not self._tpu_weights_initialized: 2101 return self._cpu_model 2102 2103 logging.info('Copying TPU weights to the CPU') 2104 tpu_weights = self._tpu_model.get_weights() 2105 2106 # TFOptimizers have no configurable options 2107 if isinstance(self.cpu_optimizer, keras_optimizers.TFOptimizer): 2108 tpu_optimizer_config = {} 2109 else: 2110 tpu_optimizer_config = self._tpu_model.optimizer.get_config() 2111 2112 self._cpu_model.set_weights(tpu_weights) 2113 for k, v in six.iteritems(tpu_optimizer_config): 2114 logging.info('TPU -> CPU %s: %s', k, v) 2115 if k == 'name': 2116 continue 2117 opt_var = getattr(self.cpu_optimizer, k) 2118 if isinstance(opt_var, variables.Variable): 2119 K.get_session().run(opt_var.assign(v)) 2120 else: 2121 logging.warning('Cannot update non-variable config: %s', k) 2122 2123 return self._cpu_model 2124 2125 def get_weights(self): 2126 return self.sync_to_cpu().get_weights() 2127 2128 def save_weights(self, *args, **kw): 2129 return self.sync_to_cpu().save_weights(*args, **kw) 2130 2131 def save(self, *args, **kw): 2132 return self.sync_to_cpu().save(*args, **kw) 2133 2134 def set_weights(self, weights): 2135 # We may not have a TPU model available if we haven't run fit/predict, so 2136 # we can't directly set the TPU weights here. 2137 # Instead, reset CPU model weights and force TPU re-initialization at the 2138 # next call. 2139 self._cpu_model.set_weights(weights) 2140 self._tpu_weights_initialized = False 2141 2142 def load_weights(self, filepath, by_name=False): 2143 self._cpu_model.load_weights(filepath, by_name) 2144 self._tpu_weights_initialized = False 2145 2146 2147# pylint: disable=bad-continuation 2148def _validate_shapes(model): 2149 """Validate that all layers in `model` have constant shape.""" 2150 for layer in model.layers: 2151 if isinstance(layer.input_shape, tuple): 2152 input_shapes = [layer.input_shape] 2153 else: 2154 input_shapes = layer.input_shape 2155 2156 if isinstance(layer.output_shape, tuple): 2157 output_shapes = [layer.output_shape] 2158 else: 2159 output_shapes = layer.output_shape 2160 2161 for shape in input_shapes + output_shapes: 2162 for dim in shape[1:]: 2163 if dim is None: 2164 raise ValueError( 2165 """ 2166Layer %(layer)s has a variable shape in a non-batch dimension. TPU models must 2167have constant shapes for all operations. 2168 2169You may have to specify `input_length` for RNN/TimeDistributed layers. 2170 2171Layer: %(layer)s 2172Input shape: %(input_shape)s 2173Output shape: %(output_shape)s 2174 """ % { 2175 'layer': layer, 2176 'input_shape': layer.input_shape, 2177 'output_shape': layer.output_shape 2178 }) 2179 2180 2181# pylint: enable=bad-continuation 2182 2183 2184@deprecated( 2185 '2019-02-20', 'Switch to tf.contrib.distribute.TPUStrategy. ' 2186 'https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy' 2187) 2188def tpu_model(model, strategy=None): 2189 """Copy `model` along with weights to the TPU. 2190 2191 Returns a TPU model. 2192 2193 Usage: 2194 ``` 2195 a = Input(shape=(32,)) 2196 b = Dense(32)(a) 2197 model = Model(inputs=a, outputs=b) 2198 2199 # If `num_cores_per_host` is greater than one, batch parallelism will be used 2200 # to run on multiple TPU cores. 2201 strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver) 2202 model = keras_support.tpu_model(model, strategy) 2203 model.compile( 2204 optimizer=tf.train.GradientDescentOptimizer(learning_rate=1.0), 2205 ...) 2206 ``` 2207 2208 Args: 2209 model: A `tf.keras.Model` instance. 2210 strategy: `TPUDistributionStrategy`. The strategy to use for replicating 2211 model across multiple TPU cores. 2212 2213 Returns: 2214 A new `KerasTPUModel` instance. 2215 """ 2216 _validate_shapes(model) 2217 # TODO(xiejw): Validate TPU model. TPUModel only? 2218 # TODO(xiejw): Validate replicas. Full or 1. Shall we allow subset? 2219 # TODO(xiejw): Adds reduction option. 2220 2221 if strategy is None: 2222 strategy = TPUDistributionStrategy() 2223 else: 2224 if not isinstance(strategy, TPUDistributionStrategy): 2225 raise TypeError( 2226 '`strategy` must have type `tf.contrib.tpu.TPUDistributionStrategy`. ' 2227 'Got: {}'.format(type(strategy))) 2228 2229 # If the model has already been initialized, grab the optimizer configuration 2230 # and model weights before entering the TPU session. 2231 if model.optimizer: 2232 if (isinstance(model.optimizer, keras_optimizers.Optimizer) and not 2233 isinstance(model.optimizer, keras_optimizers.TFOptimizer)): 2234 optimizer_config = model.optimizer.get_config() 2235 else: 2236 optimizer_config = None 2237 model_weights = model.get_weights() 2238 else: 2239 model_weights = None 2240 2241 setup_tpu_session(strategy._tpu_cluster_resolver) 2242 2243 # Force initialization of the CPU model in the TPU session. 2244 cpu_model = models.clone_model(model) 2245 if model.optimizer: 2246 cpu_model.compile( 2247 _clone_optimizer(model.optimizer, optimizer_config), 2248 model.loss, 2249 metrics_module.clone_metrics(model._compile_metrics), 2250 model.loss_weights, 2251 model.sample_weight_mode, 2252 metrics_module.clone_metrics(model._compile_weighted_metrics), 2253 ) 2254 2255 if model_weights: 2256 cpu_model.set_weights(model_weights) 2257 cpu_model.reset_states() 2258 2259 return KerasTPUModel(cpu_model=cpu_model, strategy=strategy) 2260