# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""ResNet v2 models for Keras."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras.applications import imagenet_utils
from tensorflow.python.keras.applications import resnet
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.applications.resnet_v2.ResNet50V2',
              'keras.applications.ResNet50V2')
def ResNet50V2(include_top=True,
               weights='imagenet',
               input_tensor=None,
               input_shape=None,
               pooling=None,
               classes=1000):
  """Instantiates the ResNet50V2 architecture."""
  # Four `stack2` stages (conv2..conv5); the V2 variants in this module differ
  # only in the third argument of the conv3/conv4 stages (4/6 here).
  def stack_fn(x):
    x = resnet.stack2(x, 64, 3, name='conv2')
    x = resnet.stack2(x, 128, 4, name='conv3')
    x = resnet.stack2(x, 256, 6, name='conv4')
    return resnet.stack2(x, 512, 3, stride1=1, name='conv5')

  return resnet.ResNet(stack_fn, True, True, 'resnet50v2', include_top, weights,
                       input_tensor, input_shape, pooling, classes)


@keras_export('keras.applications.resnet_v2.ResNet101V2',
              'keras.applications.ResNet101V2')
def ResNet101V2(include_top=True,
                weights='imagenet',
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000):
  """Instantiates the ResNet101V2 architecture."""
  # Same stage layout as ResNet50V2, with 23 instead of 6 for conv4.
  def stack_fn(x):
    x = resnet.stack2(x, 64, 3, name='conv2')
    x = resnet.stack2(x, 128, 4, name='conv3')
    x = resnet.stack2(x, 256, 23, name='conv4')
    return resnet.stack2(x, 512, 3, stride1=1, name='conv5')

  return resnet.ResNet(stack_fn, True, True, 'resnet101v2', include_top,
                       weights, input_tensor, input_shape, pooling, classes)


@keras_export('keras.applications.resnet_v2.ResNet152V2',
              'keras.applications.ResNet152V2')
def ResNet152V2(include_top=True,
                weights='imagenet',
                input_tensor=None,
                input_shape=None,
                pooling=None,
                classes=1000):
  """Instantiates the ResNet152V2 architecture."""
  # Same stage layout as ResNet50V2, with 8 for conv3 and 36 for conv4.
  def stack_fn(x):
    x = resnet.stack2(x, 64, 3, name='conv2')
    x = resnet.stack2(x, 128, 8, name='conv3')
    x = resnet.stack2(x, 256, 36, name='conv4')
    return resnet.stack2(x, 512, 3, stride1=1, name='conv5')

  return resnet.ResNet(stack_fn, True, True, 'resnet152v2', include_top,
                       weights, input_tensor, input_shape, pooling, classes)


@keras_export('keras.applications.resnet_v2.preprocess_input')
def preprocess_input(x, data_format=None):
  """Preprocesses inputs for the ResNetV2 models.

  Thin wrapper around `imagenet_utils.preprocess_input` that fixes
  `mode='tf'`.

  NOTE(review): per the Keras `imagenet_utils` documentation, `'tf'` mode
  scales values to the range [-1, 1] sample-wise; confirm against the
  installed `imagenet_utils` if this matters to a caller.

  Arguments:
    x: input data, forwarded unchanged to `imagenet_utils.preprocess_input`.
    data_format: optional data format of `x`, forwarded unchanged.

  Returns:
    The result of
    `imagenet_utils.preprocess_input(x, data_format=data_format, mode='tf')`.
  """
  return imagenet_utils.preprocess_input(
      x, data_format=data_format, mode='tf')


@keras_export('keras.applications.resnet_v2.decode_predictions')
def decode_predictions(preds, top=5):
  """Decodes the predictions of a ResNetV2 model.

  Thin wrapper around `imagenet_utils.decode_predictions`.

  Arguments:
    preds: model predictions, forwarded unchanged.
    top: how many top-scoring entries to return per prediction (default 5),
      forwarded unchanged.

  Returns:
    The result of `imagenet_utils.decode_predictions(preds, top=top)`.
  """
  return imagenet_utils.decode_predictions(preds, top=top)


DOC = """

  Optionally loads weights pre-trained on ImageNet.
  Note that the data format convention used by the model is
  the one specified in your Keras config at `~/.keras/keras.json`.

  Arguments:
    include_top: whether to include the fully-connected
      layer at the top of the network.
    weights: one of `None` (random initialization),
      'imagenet' (pre-training on ImageNet),
      or the path to the weights file to be loaded.
    input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
      to use as image input for the model.
    input_shape: optional shape tuple, only to be specified
      if `include_top` is False (otherwise the input shape
      has to be `(224, 224, 3)` with `'channels_last'` data format
      or `(3, 224, 224)` with `'channels_first'` data format).
      It should have exactly 3 input channels,
      and width and height should be no smaller than 32.
      E.g. `(200, 200, 3)` would be one valid value.
    pooling: Optional pooling mode for feature extraction
      when `include_top` is `False`.
      - `None` means that the output of the model will be
          the 4D tensor output of the
          last convolutional block.
      - `avg` means that global average pooling
          will be applied to the output of the
          last convolutional block, and thus
          the output of the model will be a 2D tensor.
      - `max` means that global max pooling will
          be applied.
    classes: optional number of classes to classify images
      into, only to be specified if `include_top` is True, and
      if no `weights` argument is specified.

  Returns:
    A Keras model instance.
"""


def _append_shared_doc(fn):
  """Appends the shared `DOC` text to `fn.__doc__`.

  When Python runs with `-OO`, docstrings are stripped and `fn.__doc__` is
  `None`. In that case there is nothing to extend, so `fn` is left untouched.
  Concatenating unconditionally would raise `TypeError` (`None + str`) at
  import time.
  """
  if fn.__doc__ is not None:
    fn.__doc__ += DOC


_append_shared_doc(ResNet50V2)
_append_shared_doc(ResNet101V2)
_append_shared_doc(ResNet152V2)