# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tensorflow layers with added variables for parameter masking.

Branched from tensorflow/contrib/layers/python/layers/layers.py
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.contrib.model_pruning.python.layers import core_layers as core
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables


def _model_variable_getter(getter,
                           name,
                           shape=None,
                           dtype=None,
                           initializer=None,
                           regularizer=None,
                           trainable=True,
                           collections=None,
                           caching_device=None,
                           partitioner=None,
                           rename=None,
                           use_resource=None,
                           **_):
  """Getter that uses model_variable for compatibility with core layers."""
  short_name = name.split('/')[-1]
  if rename and short_name in rename:
    name_components = name.split('/')
    name_components[-1] = rename[short_name]
    name = '/'.join(name_components)
  return variables.model_variable(
      name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      regularizer=regularizer,
      collections=collections,
      trainable=trainable,
      caching_device=caching_device,
      partitioner=partitioner,
      custom_getter=getter,
      use_resource=use_resource)


def _build_variable_getter(rename=None):
  """Build a model variable getter that respects scope getter and renames."""

  # VariableScope will nest the getters
  def layer_variable_getter(getter, *args, **kwargs):
    kwargs['rename'] = rename
    return _model_variable_getter(getter, *args, **kwargs)

  return layer_variable_getter


def _add_variable_to_collections(variable, collections_set, collections_name):
  """Adds variable (or all its parts) to all collections with that name."""
  collections = utils.get_variable_collections(collections_set,
                                               collections_name) or []
  variables_list = [variable]
  if isinstance(variable, tf_variables.PartitionedVariable):
    variables_list = [v for v in variable]
  for collection in collections:
    for var in variables_list:
      if var not in ops.get_collection(collection):
        ops.add_to_collection(collection, var)


@add_arg_scope
def masked_convolution(inputs,
                       num_outputs,
                       kernel_size,
                       stride=1,
                       padding='SAME',
                       data_format=None,
                       rate=1,
                       activation_fn=nn.relu,
                       normalizer_fn=None,
                       normalizer_params=None,
                       weights_initializer=initializers.xavier_initializer(),
                       weights_regularizer=None,
                       biases_initializer=init_ops.zeros_initializer(),
                       biases_regularizer=None,
                       reuse=None,
                       variables_collections=None,
                       outputs_collections=None,
                       trainable=True,
                       scope=None):
  """Adds a 2D convolution followed by an optional batch_norm layer.

  The layer creates a mask variable on top of the weight variable. The input
  to the convolution operation is the elementwise multiplication of the mask
  variable and the weight variable.

  It is required that 1 <= N <= 3.

  `convolution` creates a variable called `weights`, representing the
  convolutional kernel, that is convolved (actually cross-correlated) with the
  `inputs` to produce a `Tensor` of activations. If a `normalizer_fn` is
  provided (such as `batch_norm`), it is then applied. Otherwise, if
  `normalizer_fn` is None and a `biases_initializer` is provided then a
  `biases` variable would be created and added to the activations. Finally, if
  `activation_fn` is not `None`, it is applied to the activations as well.

  Performs atrous convolution with input stride/dilation rate equal to `rate`
  if a value > 1 for any dimension of `rate` is specified. In this case
  `stride` values != 1 are not supported.

  Args:
    inputs: A Tensor of rank N+2 of shape
      `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
      not start with "NC" (default), or
      `[batch_size, in_channels] + input_spatial_shape` if data_format starts
      with "NC".
    num_outputs: Integer, the number of output filters.
    kernel_size: A sequence of N positive integers specifying the spatial
      dimensions of the filters. Can be a single integer to specify the same
      value for all spatial dimensions.
    stride: A sequence of N positive integers specifying the stride at which
      to compute output. Can be a single integer to specify the same value for
      all spatial dimensions. Specifying any `stride` value != 1 is
      incompatible with specifying any `rate` value != 1.
    padding: One of `"VALID"` or `"SAME"`.
    data_format: A string or None. Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if
      `data_format` does not start with "NC"), or the second dimension (if
      `data_format` starts with "NC"). For N=1, the valid values are "NWC"
      (default) and "NCW". For N=2, the valid values are "NHWC" (default) and
      "NCHW". For N=3, the valid values are "NDHWC" (default) and "NCDHW".
    rate: A sequence of N positive integers specifying the dilation rate to
      use for atrous convolution. Can be a single integer to specify the same
      value for all spatial dimensions. Specifying any `rate` value != 1 is
      incompatible with specifying any `stride` value != 1.
    activation_fn: Activation function. The default value is a ReLU function.
      Explicitly set it to None to skip it and maintain a linear activation.
    normalizer_fn: Normalization function to use instead of `biases`. If
      `normalizer_fn` is provided then `biases_initializer` and
      `biases_regularizer` are ignored and `biases` are not created nor added.
      Default is None, for no normalizer function.
    normalizer_params: Normalization function parameters.
    weights_initializer: An initializer for the weights.
    weights_regularizer: Optional regularizer for the weights.
    biases_initializer: An initializer for the biases. If None skip biases.
    biases_regularizer: Optional regularizer for the biases.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional list of collections for all the variables
      or a dictionary containing a different list of collections per variable.
    outputs_collections: Collection to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for `variable_scope`.

  Returns:
    A tensor representing the output of the operation.

  Raises:
    ValueError: If `data_format` is invalid.
    ValueError: If both `rate` and `stride` are not uniformly 1.
  """
  if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
    raise ValueError('Invalid data_format: %r' % (data_format,))

  layer_variable_getter = _build_variable_getter({
      'bias': 'biases',
      'kernel': 'weights'
  })

  with variable_scope.variable_scope(
      scope, 'Conv', [inputs], reuse=reuse,
      custom_getter=layer_variable_getter) as sc:
    inputs = ops.convert_to_tensor(inputs)
    input_rank = inputs.get_shape().ndims

    if input_rank == 3:
      raise ValueError('Sparse Convolution not supported for input with rank',
                       input_rank)
    elif input_rank == 4:
      layer_class = core.MaskedConv2D
    elif input_rank == 5:
      raise ValueError('Sparse Convolution not supported for input with rank',
                       input_rank)
    else:
      raise ValueError('Sparse Convolution not supported for input with rank',
                       input_rank)

    if data_format is None or data_format == 'NHWC':
      df = 'channels_last'
    elif data_format == 'NCHW':
      df = 'channels_first'
    else:
      raise ValueError('Unsupported data format', data_format)

    layer = layer_class(
        filters=num_outputs,
        kernel_size=kernel_size,
        strides=stride,
        padding=padding,
        data_format=df,
        dilation_rate=rate,
        activation=None,
        use_bias=not normalizer_fn and biases_initializer,
        kernel_initializer=weights_initializer,
        bias_initializer=biases_initializer,
        kernel_regularizer=weights_regularizer,
        bias_regularizer=biases_regularizer,
        activity_regularizer=None,
        trainable=trainable,
        name=sc.name,
        dtype=inputs.dtype.base_dtype,
        _scope=sc,
        _reuse=reuse)
    outputs = layer.apply(inputs)

    # Add variables to collections.
    _add_variable_to_collections(layer.kernel, variables_collections,
                                 'weights')
    if layer.use_bias:
      _add_variable_to_collections(layer.bias, variables_collections, 'biases')

    if normalizer_fn is not None:
      normalizer_params = normalizer_params or {}
      outputs = normalizer_fn(outputs, **normalizer_params)

    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections,
                                       sc.original_name_scope, outputs)


masked_conv2d = masked_convolution


@add_arg_scope
def masked_fully_connected(
    inputs,
    num_outputs,
    activation_fn=nn.relu,
    normalizer_fn=None,
    normalizer_params=None,
    weights_initializer=initializers.xavier_initializer(),
    weights_regularizer=None,
    biases_initializer=init_ops.zeros_initializer(),
    biases_regularizer=None,
    reuse=None,
    variables_collections=None,
    outputs_collections=None,
    trainable=True,
    scope=None):
  """Adds a sparse fully connected layer. The weight matrix is masked.

  `fully_connected` creates a variable called `weights`, representing a fully
  connected weight matrix, which is multiplied by the `inputs` to produce a
  `Tensor` of hidden units. If a `normalizer_fn` is provided (such as
  `batch_norm`), it is then applied. Otherwise, if `normalizer_fn` is
  None and a `biases_initializer` is provided then a `biases` variable would
  be created and added to the hidden units. Finally, if `activation_fn` is not
  `None`, it is applied to the hidden units as well.

  Note that if `inputs` have a rank greater than 2, then `inputs` is flattened
  prior to the initial matrix multiply by `weights`.

  Args:
    inputs: A tensor of at least rank 2 and static value for the last
      dimension; i.e. `[batch_size, depth]`, `[None, None, None, channels]`.
    num_outputs: Integer or long, the number of output units in the layer.
    activation_fn: Activation function. The default value is a ReLU function.
      Explicitly set it to None to skip it and maintain a linear activation.
    normalizer_fn: Normalization function to use instead of `biases`. If
      `normalizer_fn` is provided then `biases_initializer` and
      `biases_regularizer` are ignored and `biases` are not created nor added.
      Default is None, for no normalizer function.
    normalizer_params: Normalization function parameters.
    weights_initializer: An initializer for the weights.
    weights_regularizer: Optional regularizer for the weights.
    biases_initializer: An initializer for the biases. If None skip biases.
    biases_regularizer: Optional regularizer for the biases.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional list of collections for all the variables
      or a dictionary containing a different list of collections per variable.
    outputs_collections: Collection to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for variable_scope.

  Returns:
    The tensor variable representing the result of the series of operations.

  Raises:
    ValueError: If `inputs` has rank less than 2 or if its last dimension is
      not set.
  """
  if not isinstance(num_outputs, six.integer_types):
    raise ValueError('num_outputs should be int or long, got %s.' %
                     (num_outputs,))

  layer_variable_getter = _build_variable_getter({
      'bias': 'biases',
      'kernel': 'weights'
  })

  with variable_scope.variable_scope(
      scope,
      'fully_connected', [inputs],
      reuse=reuse,
      custom_getter=layer_variable_getter) as sc:
    inputs = ops.convert_to_tensor(inputs)
    layer = core.MaskedFullyConnected(
        units=num_outputs,
        activation=None,
        use_bias=not normalizer_fn and biases_initializer,
        kernel_initializer=weights_initializer,
        bias_initializer=biases_initializer,
        kernel_regularizer=weights_regularizer,
        bias_regularizer=biases_regularizer,
        activity_regularizer=None,
        trainable=trainable,
        name=sc.name,
        dtype=inputs.dtype.base_dtype,
        _scope=sc,
        _reuse=reuse)
    outputs = layer.apply(inputs)

    # Add variables to collections.
    _add_variable_to_collections(layer.kernel, variables_collections,
                                 'weights')
    if layer.bias is not None:
      _add_variable_to_collections(layer.bias, variables_collections, 'biases')

    # Apply normalizer function / layer.
    if normalizer_fn is not None:
      if not normalizer_params:
        normalizer_params = {}
      outputs = normalizer_fn(outputs, **normalizer_params)

    if activation_fn is not None:
      outputs = activation_fn(outputs)

    return utils.collect_named_outputs(outputs_collections,
                                       sc.original_name_scope, outputs)
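

# A minimal usage sketch, not part of the library API: it assumes a TF 1.x
# graph-mode environment with `tensorflow.contrib` available and shows how the
# masked layers above can be composed. The placeholder shape, layer sizes and
# scope names are illustrative only, and it is assumed that the mask variables
# are registered in the collection named by core.MASK_COLLECTION.
if __name__ == '__main__':
  import tensorflow as tf

  # NHWC image batch; the shape here is an arbitrary example.
  images = tf.placeholder(tf.float32, [None, 28, 28, 1])

  # Each masked layer creates `weights`, a same-shaped `mask` and a
  # `threshold` variable; the convolution / matmul consumes mask * weights,
  # so a pruning schedule can zero out entries of `mask` without touching
  # `weights`.
  net = masked_conv2d(images, num_outputs=32, kernel_size=3, scope='conv1')
  net = tf.layers.flatten(net)
  logits = masked_fully_connected(net, 10, activation_fn=None, scope='fc1')
  print('logits:', logits)

  # The mask variables are exposed through a dedicated graph collection,
  # which the model_pruning utilities read when applying sparsity.
  for mask in ops.get_collection(core.MASK_COLLECTION):
    print(mask.op.name, mask.get_shape())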