1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Functions for saving and loading a Keras Model from HDF5 format. 17""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import json 23import os 24 25import numpy as np 26from six.moves import zip # pylint: disable=redefined-builtin 27 28from tensorflow.python.keras import backend as K 29from tensorflow.python.keras import optimizers 30from tensorflow.python.keras.saving import model_config as model_config_lib 31from tensorflow.python.keras.utils import conv_utils 32from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite 33from tensorflow.python.platform import tf_logging as logging 34from tensorflow.python.util import serialization 35from tensorflow.python.util.tf_export import keras_export 36 37# pylint: disable=g-import-not-at-top 38try: 39 import h5py 40 HDF5_OBJECT_HEADER_LIMIT = 64512 41except ImportError: 42 h5py = None 43# pylint: enable=g-import-not-at-top 44 45 46@keras_export('keras.models.save_model') 47def save_model(model, filepath, overwrite=True, include_optimizer=True): 48 """Saves a model to a HDF5 file. 49 50 The saved model contains: 51 - the model's configuration (topology) 52 - the model's weights 53 - the model's optimizer's state (if any) 54 55 Thus the saved model can be reinstantiated in 56 the exact same state, without any of the code 57 used for model definition or training. 58 59 Arguments: 60 model: Keras model instance to be saved. 61 filepath: One of the following: 62 - String, path where to save the model 63 - `h5py.File` object where to save the model 64 overwrite: Whether we should overwrite any existing 65 model at the target location, or instead 66 ask the user with a manual prompt. 67 include_optimizer: If True, save optimizer's state together. 68 69 Raises: 70 ImportError: if h5py is not available. 71 """ 72 73 if h5py is None: 74 raise ImportError('`save_model` requires h5py.') 75 76 from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top 77 78 # TODO(psv) Add warning when we save models that contain non-serializable 79 # entities like metrics added using `add_metric` and losses added using 80 # `add_loss.` 81 82 if not isinstance(filepath, h5py.File): 83 # If file exists and should not be overwritten. 84 if not overwrite and os.path.isfile(filepath): 85 proceed = ask_to_proceed_with_overwrite(filepath) 86 if not proceed: 87 return 88 89 f = h5py.File(filepath, mode='w') 90 opened_new_file = True 91 else: 92 f = filepath 93 opened_new_file = False 94 95 try: 96 f.attrs['keras_version'] = str(keras_version).encode('utf8') 97 f.attrs['backend'] = K.backend().encode('utf8') 98 f.attrs['model_config'] = json.dumps( 99 { 100 'class_name': model.__class__.__name__, 101 'config': model.get_config() 102 }, 103 default=serialization.get_json_type).encode('utf8') 104 105 model_weights_group = f.create_group('model_weights') 106 model_layers = model.layers 107 save_weights_to_hdf5_group(model_weights_group, model_layers) 108 109 if include_optimizer and model.optimizer: 110 if isinstance(model.optimizer, optimizers.TFOptimizer): 111 logging.warning( 112 'TensorFlow optimizers do not ' 113 'make it possible to access ' 114 'optimizer attributes or optimizer state ' 115 'after instantiation. ' 116 'As a result, we cannot save the optimizer ' 117 'as part of the model save file. ' 118 'You will have to compile your model again after loading it. ' 119 'Prefer using a Keras optimizer instead ' 120 '(see keras.io/optimizers).') 121 else: 122 f.attrs['training_config'] = json.dumps( 123 { 124 'optimizer_config': { 125 'class_name': model.optimizer.__class__.__name__, 126 'config': model.optimizer.get_config() 127 }, 128 'loss': model.loss, 129 'metrics': model._compile_metrics, 130 'weighted_metrics': model._compile_weighted_metrics, 131 'sample_weight_mode': model.sample_weight_mode, 132 'loss_weights': model.loss_weights, 133 }, 134 default=serialization.get_json_type).encode('utf8') 135 136 # Save optimizer weights. 137 save_optimizer_weights_to_hdf5_group(f, model.optimizer) 138 f.flush() 139 finally: 140 if opened_new_file: 141 f.close() 142 143 144@keras_export('keras.models.load_model') 145def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=redefined-builtin 146 """Loads a model saved via `save_model`. 147 148 Arguments: 149 filepath: One of the following: 150 - String, path to the saved model 151 - `h5py.File` object from which to load the model 152 custom_objects: Optional dictionary mapping names 153 (strings) to custom classes or functions to be 154 considered during deserialization. 155 compile: Boolean, whether to compile the model 156 after loading. 157 158 Returns: 159 A Keras model instance. If an optimizer was found 160 as part of the saved model, the model is already 161 compiled. Otherwise, the model is uncompiled and 162 a warning will be displayed. When `compile` is set 163 to False, the compilation is omitted without any 164 warning. 165 166 Raises: 167 ImportError: if h5py is not available. 168 ValueError: In case of an invalid savefile. 169 """ 170 if h5py is None: 171 raise ImportError('`load_model` requires h5py.') 172 173 if not custom_objects: 174 custom_objects = {} 175 176 def convert_custom_objects(obj): 177 """Handles custom object lookup. 178 179 Arguments: 180 obj: object, dict, or list. 181 182 Returns: 183 The same structure, where occurrences 184 of a custom object name have been replaced 185 with the custom object. 186 """ 187 if isinstance(obj, list): 188 deserialized = [] 189 for value in obj: 190 deserialized.append(convert_custom_objects(value)) 191 return deserialized 192 if isinstance(obj, dict): 193 deserialized = {} 194 for key, value in obj.items(): 195 deserialized[key] = convert_custom_objects(value) 196 return deserialized 197 if obj in custom_objects: 198 return custom_objects[obj] 199 return obj 200 201 opened_new_file = not isinstance(filepath, h5py.File) 202 if opened_new_file: 203 f = h5py.File(filepath, mode='r') 204 else: 205 f = filepath 206 207 model = None 208 try: 209 # instantiate model 210 model_config = f.attrs.get('model_config') 211 if model_config is None: 212 raise ValueError('No model found in config file.') 213 model_config = json.loads(model_config.decode('utf-8')) 214 model = model_config_lib.model_from_config(model_config, 215 custom_objects=custom_objects) 216 217 # set weights 218 load_weights_from_hdf5_group(f['model_weights'], model.layers) 219 220 if compile: 221 # instantiate optimizer 222 training_config = f.attrs.get('training_config') 223 if training_config is None: 224 logging.warning('No training configuration found in save file: ' 225 'the model was *not* compiled. Compile it manually.') 226 return model 227 training_config = json.loads(training_config.decode('utf-8')) 228 optimizer_config = training_config['optimizer_config'] 229 optimizer = optimizers.deserialize( 230 optimizer_config, custom_objects=custom_objects) 231 232 # Recover loss functions and metrics. 233 loss = convert_custom_objects(training_config['loss']) 234 metrics = convert_custom_objects(training_config['metrics']) 235 weighted_metrics = convert_custom_objects( 236 training_config.get('weighted_metrics', None)) 237 sample_weight_mode = training_config['sample_weight_mode'] 238 loss_weights = training_config['loss_weights'] 239 240 # Compile model. 241 model.compile( 242 optimizer=optimizer, 243 loss=loss, 244 metrics=metrics, 245 weighted_metrics=weighted_metrics, 246 loss_weights=loss_weights, 247 sample_weight_mode=sample_weight_mode) 248 249 # Set optimizer weights. 250 if 'optimizer_weights' in f: 251 # Build train function (to get weight updates). 252 # Models that aren't graph networks must wait until they are called 253 # with data to _make_train_function() and so can't load optimizer 254 # weights. 255 if model._is_graph_network: # pylint: disable=protected-access 256 model._make_train_function() 257 optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f) 258 try: 259 model.optimizer.set_weights(optimizer_weight_values) 260 except ValueError: 261 logging.warning('Error in loading the saved optimizer ' 262 'state. As a result, your model is ' 263 'starting with a freshly initialized ' 264 'optimizer.') 265 else: 266 logging.warning('Sequential models without an `input_shape` ' 267 'passed to the first layer cannot reload their ' 268 'optimizer state. As a result, your model is' 269 'starting with a freshly initialized optimizer.') 270 271 finally: 272 if opened_new_file: 273 f.close() 274 return model 275 276 277def preprocess_weights_for_loading(layer, 278 weights, 279 original_keras_version=None, 280 original_backend=None): 281 """Preprocess layer weights between different Keras formats. 282 283 Converts layers weights from Keras 1 format to Keras 2 and also weights of 284 CuDNN layers in Keras 2. 285 286 Arguments: 287 layer: Layer instance. 288 weights: List of weights values (Numpy arrays). 289 original_keras_version: Keras version for the weights, as a string. 290 original_backend: Keras backend the weights were trained with, 291 as a string. 292 293 Returns: 294 A list of weights values (Numpy arrays). 295 """ 296 def convert_nested_bidirectional(weights): 297 """Converts layers nested in `Bidirectional` wrapper. 298 299 This function uses `preprocess_weights_for_loading()` for converting 300 layers. 301 302 Arguments: 303 weights: List of weights values (Numpy arrays). 304 305 Returns: 306 A list of weights values (Numpy arrays). 307 """ 308 num_weights_per_layer = len(weights) // 2 309 forward_weights = preprocess_weights_for_loading( 310 layer.forward_layer, weights[:num_weights_per_layer], 311 original_keras_version, original_backend) 312 backward_weights = preprocess_weights_for_loading( 313 layer.backward_layer, weights[num_weights_per_layer:], 314 original_keras_version, original_backend) 315 return forward_weights + backward_weights 316 317 def convert_nested_time_distributed(weights): 318 """Converts layers nested in `TimeDistributed` wrapper. 319 320 This function uses `preprocess_weights_for_loading()` for converting nested 321 layers. 322 323 Arguments: 324 weights: List of weights values (Numpy arrays). 325 326 Returns: 327 A list of weights values (Numpy arrays). 328 """ 329 return preprocess_weights_for_loading( 330 layer.layer, weights, original_keras_version, original_backend) 331 332 def convert_nested_model(weights): 333 """Converts layers nested in `Model` or `Sequential`. 334 335 This function uses `preprocess_weights_for_loading()` for converting nested 336 layers. 337 338 Arguments: 339 weights: List of weights values (Numpy arrays). 340 341 Returns: 342 A list of weights values (Numpy arrays). 343 """ 344 new_weights = [] 345 # trainable weights 346 for sublayer in layer.layers: 347 num_weights = len(sublayer.trainable_weights) 348 if num_weights > 0: 349 new_weights.extend(preprocess_weights_for_loading( 350 layer=sublayer, 351 weights=weights[:num_weights], 352 original_keras_version=original_keras_version, 353 original_backend=original_backend)) 354 weights = weights[num_weights:] 355 356 # non-trainable weights 357 for sublayer in layer.layers: 358 num_weights = len([l for l in sublayer.weights 359 if l not in sublayer.trainable_weights]) 360 if num_weights > 0: 361 new_weights.extend(preprocess_weights_for_loading( 362 layer=sublayer, 363 weights=weights[:num_weights], 364 original_keras_version=original_keras_version, 365 original_backend=original_backend)) 366 weights = weights[num_weights:] 367 return new_weights 368 369 # Convert layers nested in Bidirectional/Model/Sequential. 370 # Both transformation should be ran for both Keras 1->2 conversion 371 # and for conversion of CuDNN layers. 372 if layer.__class__.__name__ == 'Bidirectional': 373 weights = convert_nested_bidirectional(weights) 374 if layer.__class__.__name__ == 'TimeDistributed': 375 weights = convert_nested_time_distributed(weights) 376 elif layer.__class__.__name__ in ['Model', 'Sequential']: 377 weights = convert_nested_model(weights) 378 379 if original_keras_version == '1': 380 if layer.__class__.__name__ == 'TimeDistributed': 381 weights = preprocess_weights_for_loading( 382 layer.layer, weights, original_keras_version, original_backend) 383 384 if layer.__class__.__name__ == 'Conv1D': 385 shape = weights[0].shape 386 # Handle Keras 1.1 format 387 if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters: 388 # Legacy shape: 389 # (filters, input_dim, filter_length, 1) 390 assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0], 391 1) 392 weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) 393 weights[0] = weights[0][:, 0, :, :] 394 395 if layer.__class__.__name__ == 'Conv2D': 396 if layer.data_format == 'channels_first': 397 # old: (filters, stack_size, kernel_rows, kernel_cols) 398 # new: (kernel_rows, kernel_cols, stack_size, filters) 399 weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) 400 401 if layer.__class__.__name__ == 'Conv2DTranspose': 402 if layer.data_format == 'channels_last': 403 # old: (kernel_rows, kernel_cols, stack_size, filters) 404 # new: (kernel_rows, kernel_cols, filters, stack_size) 405 weights[0] = np.transpose(weights[0], (0, 1, 3, 2)) 406 if layer.data_format == 'channels_first': 407 # old: (filters, stack_size, kernel_rows, kernel_cols) 408 # new: (kernel_rows, kernel_cols, filters, stack_size) 409 weights[0] = np.transpose(weights[0], (2, 3, 0, 1)) 410 411 if layer.__class__.__name__ == 'Conv3D': 412 if layer.data_format == 'channels_first': 413 # old: (filters, stack_size, ...) 414 # new: (..., stack_size, filters) 415 weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0)) 416 417 if layer.__class__.__name__ == 'GRU': 418 if len(weights) == 9: 419 kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1) 420 recurrent_kernel = np.concatenate( 421 [weights[1], weights[4], weights[7]], axis=-1) 422 bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1) 423 weights = [kernel, recurrent_kernel, bias] 424 425 if layer.__class__.__name__ == 'LSTM': 426 if len(weights) == 12: 427 # old: i, c, f, o 428 # new: i, f, c, o 429 kernel = np.concatenate( 430 [weights[0], weights[6], weights[3], weights[9]], axis=-1) 431 recurrent_kernel = np.concatenate( 432 [weights[1], weights[7], weights[4], weights[10]], axis=-1) 433 bias = np.concatenate( 434 [weights[2], weights[8], weights[5], weights[11]], axis=-1) 435 weights = [kernel, recurrent_kernel, bias] 436 437 if layer.__class__.__name__ == 'ConvLSTM2D': 438 if len(weights) == 12: 439 kernel = np.concatenate( 440 [weights[0], weights[6], weights[3], weights[9]], axis=-1) 441 recurrent_kernel = np.concatenate( 442 [weights[1], weights[7], weights[4], weights[10]], axis=-1) 443 bias = np.concatenate( 444 [weights[2], weights[8], weights[5], weights[11]], axis=-1) 445 if layer.data_format == 'channels_first': 446 # old: (filters, stack_size, kernel_rows, kernel_cols) 447 # new: (kernel_rows, kernel_cols, stack_size, filters) 448 kernel = np.transpose(kernel, (2, 3, 1, 0)) 449 recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0)) 450 weights = [kernel, recurrent_kernel, bias] 451 452 conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D'] 453 if layer.__class__.__name__ in conv_layers: 454 if original_backend == 'theano': 455 weights[0] = conv_utils.convert_kernel(weights[0]) 456 if layer.__class__.__name__ == 'ConvLSTM2D': 457 weights[1] = conv_utils.convert_kernel(weights[1]) 458 if K.int_shape(layer.weights[0]) != weights[0].shape: 459 weights[0] = np.transpose(weights[0], (3, 2, 0, 1)) 460 if layer.__class__.__name__ == 'ConvLSTM2D': 461 weights[1] = np.transpose(weights[1], (3, 2, 0, 1)) 462 463 # convert CuDNN layers 464 return _convert_rnn_weights(layer, weights) 465 466 467def _convert_rnn_weights(layer, weights): 468 """Converts weights for RNN layers between native and CuDNN format. 469 470 Input kernels for each gate are transposed and converted between Fortran 471 and C layout, recurrent kernels are transposed. For LSTM biases are summed/ 472 split in half, for GRU biases are reshaped. 473 474 Weights can be converted in both directions between `LSTM` and`CuDNNSLTM` 475 and between `CuDNNGRU` and `GRU(reset_after=True)`. Default `GRU` is not 476 compatible with `CuDNNGRU`. 477 478 For missing biases in `LSTM`/`GRU` (`use_bias=False`) no conversion is made. 479 480 Arguments: 481 layer: Target layer instance. 482 weights: List of source weights values (input kernels, recurrent 483 kernels, [biases]) (Numpy arrays). 484 485 Returns: 486 A list of converted weights values (Numpy arrays). 487 488 Raises: 489 ValueError: for incompatible GRU layer/weights or incompatible biases 490 """ 491 492 def transform_kernels(kernels, func, n_gates): 493 """Transforms kernel for each gate separately using given function. 494 495 Arguments: 496 kernels: Stacked array of kernels for individual gates. 497 func: Function applied to kernel of each gate. 498 n_gates: Number of gates (4 for LSTM, 3 for GRU). 499 500 Returns: 501 Stacked array of transformed kernels. 502 """ 503 return np.hstack([func(k) for k in np.hsplit(kernels, n_gates)]) 504 505 def transpose_input(from_cudnn): 506 """Makes a function that transforms input kernels from/to CuDNN format. 507 508 It keeps the shape, but changes between the layout (Fortran/C). Eg.: 509 510 ``` 511 Keras CuDNN 512 [[0, 1, 2], <---> [[0, 2, 4], 513 [3, 4, 5]] [1, 3, 5]] 514 ``` 515 516 It can be passed to `transform_kernels()`. 517 518 Arguments: 519 from_cudnn: `True` if source weights are in CuDNN format, `False` 520 if they're in plain Keras format. 521 522 Returns: 523 Function that converts input kernel to the other format. 524 """ 525 order = 'F' if from_cudnn else 'C' 526 527 def transform(kernel): 528 return kernel.T.reshape(kernel.shape, order=order) 529 530 return transform 531 532 target_class = layer.__class__.__name__ 533 534 # convert the weights between CuDNNLSTM and LSTM 535 if target_class in ['LSTM', 'CuDNNLSTM'] and len(weights) == 3: 536 # determine if we're loading a CuDNNLSTM layer 537 # from the number of bias weights: 538 # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4) 539 # if there's no bias weight in the file, skip this conversion 540 units = weights[1].shape[0] 541 bias_shape = weights[2].shape 542 n_gates = 4 543 544 if bias_shape == (2 * units * n_gates,): 545 source = 'CuDNNLSTM' 546 elif bias_shape == (units * n_gates,): 547 source = 'LSTM' 548 else: 549 raise ValueError('Invalid bias shape: ' + str(bias_shape)) 550 551 def convert_lstm_weights(weights, from_cudnn=True): 552 """Converts the weights between CuDNNLSTM and LSTM. 553 554 Arguments: 555 weights: Original weights. 556 from_cudnn: Indicates whether original weights are from CuDNN layer. 557 558 Returns: 559 Updated weights compatible with LSTM. 560 """ 561 562 # Transpose (and reshape) input and recurrent kernels 563 kernels = transform_kernels(weights[0], transpose_input(from_cudnn), 564 n_gates) 565 recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates) 566 if from_cudnn: 567 # merge input and recurrent biases into a single set 568 biases = np.sum(np.split(weights[2], 2, axis=0), axis=0) 569 else: 570 # Split single set of biases evenly to two sets. The way of 571 # splitting doesn't matter as long as the two sets sum is kept. 572 biases = np.tile(0.5 * weights[2], 2) 573 return [kernels, recurrent_kernels, biases] 574 575 if source != target_class: 576 weights = convert_lstm_weights(weights, from_cudnn=source == 'CuDNNLSTM') 577 578 # convert the weights between CuDNNGRU and GRU(reset_after=True) 579 if target_class in ['GRU', 'CuDNNGRU'] and len(weights) == 3: 580 # We can determine the source of the weights from the shape of the bias. 581 # If there is no bias we skip the conversion since 582 # CuDNNGRU always has biases. 583 584 units = weights[1].shape[0] 585 bias_shape = weights[2].shape 586 n_gates = 3 587 588 def convert_gru_weights(weights, from_cudnn=True): 589 """Converts the weights between CuDNNGRU and GRU. 590 591 Arguments: 592 weights: Original weights. 593 from_cudnn: Indicates whether original weights are from CuDNN layer. 594 595 Returns: 596 Updated weights compatible with GRU. 597 """ 598 599 kernels = transform_kernels(weights[0], transpose_input(from_cudnn), 600 n_gates) 601 recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates) 602 biases = np.array(weights[2]).reshape((2, -1) if from_cudnn else -1) 603 return [kernels, recurrent_kernels, biases] 604 605 if bias_shape == (2 * units * n_gates,): 606 source = 'CuDNNGRU' 607 elif bias_shape == (2, units * n_gates): 608 source = 'GRU(reset_after=True)' 609 elif bias_shape == (units * n_gates,): 610 source = 'GRU(reset_after=False)' 611 else: 612 raise ValueError('Invalid bias shape: ' + str(bias_shape)) 613 614 if target_class == 'CuDNNGRU': 615 target = 'CuDNNGRU' 616 elif layer.reset_after: 617 target = 'GRU(reset_after=True)' 618 else: 619 target = 'GRU(reset_after=False)' 620 621 # only convert between different types 622 if source != target: 623 types = (source, target) 624 if 'GRU(reset_after=False)' in types: 625 raise ValueError('%s is not compatible with %s' % types) 626 if source == 'CuDNNGRU': 627 weights = convert_gru_weights(weights, from_cudnn=True) 628 elif source == 'GRU(reset_after=True)': 629 weights = convert_gru_weights(weights, from_cudnn=False) 630 631 return weights 632 633 634def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer): 635 """Saves optimizer weights of a optimizer to a HDF5 group. 636 637 Arguments: 638 hdf5_group: HDF5 group. 639 optimizer: optimizer instance. 640 """ 641 642 symbolic_weights = getattr(optimizer, 'weights') 643 if symbolic_weights: 644 weights_group = hdf5_group.create_group('optimizer_weights') 645 weight_names = [str(w.name).encode('utf8') for w in symbolic_weights] 646 save_attributes_to_hdf5_group(weights_group, 'weight_names', weight_names) 647 weight_values = K.batch_get_value(symbolic_weights) 648 for name, val in zip(weight_names, weight_values): 649 param_dset = weights_group.create_dataset( 650 name, val.shape, dtype=val.dtype) 651 if not val.shape: 652 # scalar 653 param_dset[()] = val 654 else: 655 param_dset[:] = val 656 657 658def load_optimizer_weights_from_hdf5_group(hdf5_group): 659 """Load optimizer weights from a HDF5 group. 660 661 Arguments: 662 hdf5_group: A pointer to a HDF5 group. 663 664 Returns: 665 data: List of optimizer weight names. 666 """ 667 weights_group = hdf5_group['optimizer_weights'] 668 optimizer_weight_names = load_attributes_from_hdf5_group( 669 weights_group, 'weight_names') 670 return [weights_group[weight_name] for weight_name in optimizer_weight_names] 671 672 673def save_weights_to_hdf5_group(f, layers): 674 """Saves the weights of a list of layers to a HDF5 group. 675 676 Arguments: 677 f: HDF5 group. 678 layers: List of layer instances. 679 """ 680 from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top 681 682 save_attributes_to_hdf5_group( 683 f, 'layer_names', [layer.name.encode('utf8') for layer in layers]) 684 f.attrs['backend'] = K.backend().encode('utf8') 685 f.attrs['keras_version'] = str(keras_version).encode('utf8') 686 687 for layer in layers: 688 g = f.create_group(layer.name) 689 weight_values = K.batch_get_value(layer.weights) 690 weight_names = [w.name.encode('utf8') for w in layer.weights] 691 save_attributes_to_hdf5_group(g, 'weight_names', weight_names) 692 for name, val in zip(weight_names, weight_values): 693 param_dset = g.create_dataset(name, val.shape, dtype=val.dtype) 694 if not val.shape: 695 # scalar 696 param_dset[()] = val 697 else: 698 param_dset[:] = val 699 700 701def load_weights_from_hdf5_group(f, layers): 702 """Implements topological (order-based) weight loading. 703 704 Arguments: 705 f: A pointer to a HDF5 group. 706 layers: a list of target layers. 707 708 Raises: 709 ValueError: in case of mismatch between provided layers 710 and weights file. 711 """ 712 if 'keras_version' in f.attrs: 713 original_keras_version = f.attrs['keras_version'].decode('utf8') 714 else: 715 original_keras_version = '1' 716 if 'backend' in f.attrs: 717 original_backend = f.attrs['backend'].decode('utf8') 718 else: 719 original_backend = None 720 721 filtered_layers = [] 722 for layer in layers: 723 weights = layer.weights 724 if weights: 725 filtered_layers.append(layer) 726 727 layer_names = load_attributes_from_hdf5_group(f, 'layer_names') 728 filtered_layer_names = [] 729 for name in layer_names: 730 g = f[name] 731 weight_names = load_attributes_from_hdf5_group(g, 'weight_names') 732 if weight_names: 733 filtered_layer_names.append(name) 734 layer_names = filtered_layer_names 735 if len(layer_names) != len(filtered_layers): 736 raise ValueError('You are trying to load a weight file ' 737 'containing ' + str(len(layer_names)) + 738 ' layers into a model with ' + str(len(filtered_layers)) + 739 ' layers.') 740 741 # We batch weight value assignments in a single backend call 742 # which provides a speedup in TensorFlow. 743 weight_value_tuples = [] 744 for k, name in enumerate(layer_names): 745 g = f[name] 746 weight_names = load_attributes_from_hdf5_group(g, 'weight_names') 747 weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names] 748 layer = filtered_layers[k] 749 symbolic_weights = layer.weights 750 weight_values = preprocess_weights_for_loading( 751 layer, weight_values, original_keras_version, original_backend) 752 if len(weight_values) != len(symbolic_weights): 753 raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + 754 '" in the current model) was found to ' 755 'correspond to layer ' + name + ' in the save file. ' 756 'However the new layer ' + layer.name + ' expects ' + 757 str(len(symbolic_weights)) + 758 ' weights, but the saved weights have ' + 759 str(len(weight_values)) + ' elements.') 760 weight_value_tuples += zip(symbolic_weights, weight_values) 761 K.batch_set_value(weight_value_tuples) 762 763 764def load_weights_from_hdf5_group_by_name(f, layers): 765 """Implements name-based weight loading. 766 767 (instead of topological weight loading). 768 769 Layers that have no matching name are skipped. 770 771 Arguments: 772 f: A pointer to a HDF5 group. 773 layers: a list of target layers. 774 775 Raises: 776 ValueError: in case of mismatch between provided layers 777 and weights file. 778 """ 779 if 'keras_version' in f.attrs: 780 original_keras_version = f.attrs['keras_version'].decode('utf8') 781 else: 782 original_keras_version = '1' 783 if 'backend' in f.attrs: 784 original_backend = f.attrs['backend'].decode('utf8') 785 else: 786 original_backend = None 787 788 # New file format. 789 layer_names = load_attributes_from_hdf5_group(f, 'layer_names') 790 791 # Reverse index of layer name to list of layers with name. 792 index = {} 793 for layer in layers: 794 if layer.name: 795 index.setdefault(layer.name, []).append(layer) 796 797 # We batch weight value assignments in a single backend call 798 # which provides a speedup in TensorFlow. 799 weight_value_tuples = [] 800 for k, name in enumerate(layer_names): 801 g = f[name] 802 weight_names = load_attributes_from_hdf5_group(g, 'weight_names') 803 weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names] 804 805 for layer in index.get(name, []): 806 symbolic_weights = layer.weights 807 weight_values = preprocess_weights_for_loading( 808 layer, weight_values, original_keras_version, original_backend) 809 if len(weight_values) != len(symbolic_weights): 810 raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + 811 '") expects ' + str(len(symbolic_weights)) + 812 ' weight(s), but the saved weights' + ' have ' + 813 str(len(weight_values)) + ' element(s).') 814 # Set values. 815 for i in range(len(weight_values)): 816 if K.int_shape(symbolic_weights[i]) != weight_values[i].shape: 817 raise ValueError('Layer #' + str(k) +' (named "' + layer.name + 818 '"), weight ' + str(symbolic_weights[i]) + 819 ' has shape {}'.format(K.int_shape( 820 symbolic_weights[i])) + 821 ', but the saved weight has shape ' + 822 str(weight_values[i].shape) + '.') 823 824 else: 825 weight_value_tuples.append((symbolic_weights[i], weight_values[i])) 826 K.batch_set_value(weight_value_tuples) 827 828 829def save_attributes_to_hdf5_group(group, name, data): 830 """Saves attributes (data) of the specified name into the HDF5 group. 831 832 This method deals with an inherent problem of HDF5 file which is not 833 able to store data larger than HDF5_OBJECT_HEADER_LIMIT bytes. 834 835 Arguments: 836 group: A pointer to a HDF5 group. 837 name: A name of the attributes to save. 838 data: Attributes data to store. 839 840 Raises: 841 RuntimeError: If any single attribute is too large to be saved. 842 """ 843 # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT` 844 # because in that case even chunking the array would not make the saving 845 # possible. 846 bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT] 847 848 # Expecting this to never be true. 849 if bad_attributes: 850 raise RuntimeError('The following attributes cannot be saved to HDF5 ' 851 'file because they are larger than %d bytes: %s' % 852 (HDF5_OBJECT_HEADER_LIMIT, 853 ', '.join([x for x in bad_attributes]))) 854 855 data_npy = np.asarray(data) 856 857 num_chunks = 1 858 chunked_data = np.array_split(data_npy, num_chunks) 859 860 # This will never loop forever thanks to the test above. 861 while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data): 862 num_chunks += 1 863 chunked_data = np.array_split(data_npy, num_chunks) 864 865 if num_chunks > 1: 866 for chunk_id, chunk_data in enumerate(chunked_data): 867 group.attrs['%s%d' % (name, chunk_id)] = chunk_data 868 else: 869 group.attrs[name] = data 870 871 872def load_attributes_from_hdf5_group(group, name): 873 """Loads attributes of the specified name from the HDF5 group. 874 875 This method deals with an inherent problem 876 of HDF5 file which is not able to store 877 data larger than HDF5_OBJECT_HEADER_LIMIT bytes. 878 879 Arguments: 880 group: A pointer to a HDF5 group. 881 name: A name of the attributes to load. 882 883 Returns: 884 data: Attributes data. 885 """ 886 if name in group.attrs: 887 data = [n.decode('utf8') for n in group.attrs[name]] 888 else: 889 data = [] 890 chunk_id = 0 891 while '%s%d' % (name, chunk_id) in group.attrs: 892 data.extend( 893 [n.decode('utf8') for n in group.attrs['%s%d' % (name, chunk_id)]]) 894 chunk_id += 1 895 return data 896