# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tensorflow layers with added variables for parameter masking.
16
17Branched from tensorflow/contrib/layers/python/layers/layers.py
18"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.contrib.model_pruning.python.layers import core_layers as core
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as tf_variables


def _model_variable_getter(getter,
                           name,
                           shape=None,
                           dtype=None,
                           initializer=None,
                           regularizer=None,
                           trainable=True,
                           collections=None,
                           caching_device=None,
                           partitioner=None,
                           rename=None,
                           use_resource=None,
                           **_):
  """Getter that uses model_variable for compatibility with core layers."""
  short_name = name.split('/')[-1]
  if rename and short_name in rename:
    name_components = name.split('/')
    name_components[-1] = rename[short_name]
    name = '/'.join(name_components)
  return variables.model_variable(
      name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      regularizer=regularizer,
      collections=collections,
      trainable=trainable,
      caching_device=caching_device,
      partitioner=partitioner,
      custom_getter=getter,
      use_resource=use_resource)


def _build_variable_getter(rename=None):
  """Build a model variable getter that respects scope getter and renames."""

  # VariableScope will nest the getters
  def layer_variable_getter(getter, *args, **kwargs):
    kwargs['rename'] = rename
    return _model_variable_getter(getter, *args, **kwargs)

  return layer_variable_getter
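
# For example, a getter built with _build_variable_getter({'kernel':
# 'weights'}) will create or look up a core layer's 'conv/kernel' variable
# under the name 'conv/weights', matching the contrib layers convention.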


def _add_variable_to_collections(variable, collections_set, collections_name):
  """Adds variable (or all its parts) to all collections with that name."""
  collections = utils.get_variable_collections(collections_set,
                                               collections_name) or []
  variables_list = [variable]
  if isinstance(variable, tf_variables.PartitionedVariable):
    variables_list = [v for v in variable]
  for collection in collections:
    for var in variables_list:
      if var not in ops.get_collection(collection):
        ops.add_to_collection(collection, var)
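
# For example, _add_variable_to_collections(w, {'weights': ['my_weights']},
# 'weights') adds the variable w to the 'my_weights' graph collection
# ('my_weights' is a hypothetical collection name, for illustration).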


@add_arg_scope
def masked_convolution(inputs,
                       num_outputs,
                       kernel_size,
                       stride=1,
                       padding='SAME',
                       data_format=None,
                       rate=1,
                       activation_fn=nn.relu,
                       normalizer_fn=None,
                       normalizer_params=None,
                       weights_initializer=initializers.xavier_initializer(),
                       weights_regularizer=None,
                       biases_initializer=init_ops.zeros_initializer(),
                       biases_regularizer=None,
                       reuse=None,
                       variables_collections=None,
                       outputs_collections=None,
                       trainable=True,
                       scope=None):
115  """Adds an 2D convolution followed by an optional batch_norm layer.
116  The layer creates a mask variable on top of the weight variable. The input to
117  the convolution operation is the elementwise multiplication of the mask
118  variable and the weigh
119
120  It is required that 1 <= N <= 3.
121
122  `convolution` creates a variable called `weights`, representing the
123  convolutional kernel, that is convolved (actually cross-correlated) with the
124  `inputs` to produce a `Tensor` of activations. If a `normalizer_fn` is
125  provided (such as `batch_norm`), it is then applied. Otherwise, if
126  `normalizer_fn` is None and a `biases_initializer` is provided then a `biases`
127  variable would be created and added the activations. Finally, if
128  `activation_fn` is not `None`, it is applied to the activations as well.
129
130  Performs atrous convolution with input stride/dilation rate equal to `rate`
131  if a value > 1 for any dimension of `rate` is specified.  In this case
132  `stride` values != 1 are not supported.
133
  Args:
    inputs: A Tensor of rank N+2 of shape
      `[batch_size] + input_spatial_shape + [in_channels]` if data_format does
      not start with "NC" (default), or
      `[batch_size, in_channels] + input_spatial_shape` if data_format starts
      with "NC".
    num_outputs: Integer, the number of output filters.
    kernel_size: A sequence of N positive integers specifying the spatial
      dimensions of the filters.  Can be a single integer to specify the same
      value for all spatial dimensions.
    stride: A sequence of N positive integers specifying the stride at which to
      compute output.  Can be a single integer to specify the same value for
      all spatial dimensions.  Specifying any `stride` value != 1 is
      incompatible with specifying any `rate` value != 1.
    padding: One of `"VALID"` or `"SAME"`.
    data_format: A string or None.  Specifies whether the channel dimension of
      the `input` and output is the last dimension (default, or if
      `data_format` does not start with "NC"), or the second dimension (if
      `data_format` starts with "NC").  For N=1, the valid values are "NWC"
      (default) and "NCW".  For N=2, the valid values are "NHWC" (default) and
      "NCHW".  For N=3, the valid values are "NDHWC" (default) and "NCDHW".
    rate: A sequence of N positive integers specifying the dilation rate to use
      for atrous convolution.  Can be a single integer to specify the same
      value for all spatial dimensions.  Specifying any `rate` value != 1 is
      incompatible with specifying any `stride` value != 1.
    activation_fn: Activation function. The default value is a ReLU function.
      Explicitly set it to None to skip it and maintain a linear activation.
    normalizer_fn: Normalization function to use instead of `biases`. If
      `normalizer_fn` is provided then `biases_initializer` and
      `biases_regularizer` are ignored and `biases` are not created nor added.
      Default is None, for no normalizer function.
    normalizer_params: Normalization function parameters.
    weights_initializer: An initializer for the weights.
    weights_regularizer: Optional regularizer for the weights.
    biases_initializer: An initializer for the biases. If None skip biases.
    biases_regularizer: Optional regularizer for the biases.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional list of collections for all the variables
      or a dictionary containing a different list of collections per variable.
    outputs_collections: Collection to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for `variable_scope`.

  Returns:
    A tensor representing the output of the operation.

  Raises:
    ValueError: If `data_format` is invalid.
    ValueError: If both `rate` and `stride` are not uniformly 1.
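
  Example:
    A minimal usage sketch, assuming a rank-4 NHWC `images` tensor (the names
    here are illustrative, not part of the API):

      net = masked_conv2d(images, num_outputs=32, kernel_size=3,
                          scope='conv1')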
185  """
  if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC', 'NCDHW']:
    raise ValueError('Invalid data_format: %r' % (data_format,))

  layer_variable_getter = _build_variable_getter({
      'bias': 'biases',
      'kernel': 'weights'
  })

  with variable_scope.variable_scope(
      scope, 'Conv', [inputs], reuse=reuse,
      custom_getter=layer_variable_getter) as sc:
    inputs = ops.convert_to_tensor(inputs)
    input_rank = inputs.get_shape().ndims

    # Only 2D convolution (rank-4 inputs) has a masked implementation;
    # rank-3 (1D) and rank-5 (3D) inputs are rejected.
    if input_rank != 4:
      raise ValueError('Sparse Convolution not supported for input with rank',
                       input_rank)
    layer_class = core.MaskedConv2D

    if data_format is None or data_format == 'NHWC':
      df = 'channels_last'
    elif data_format == 'NCHW':
      df = 'channels_first'
    else:
      raise ValueError('Unsupported data format', data_format)

    layer = layer_class(
        filters=num_outputs,
        kernel_size=kernel_size,
        strides=stride,
        padding=padding,
        data_format=df,
        dilation_rate=rate,
        activation=None,
        use_bias=not normalizer_fn and biases_initializer,
        kernel_initializer=weights_initializer,
        bias_initializer=biases_initializer,
        kernel_regularizer=weights_regularizer,
        bias_regularizer=biases_regularizer,
        activity_regularizer=None,
        trainable=trainable,
        name=sc.name,
        dtype=inputs.dtype.base_dtype,
        _scope=sc,
        _reuse=reuse)
    outputs = layer.apply(inputs)

    # Add variables to collections.
    _add_variable_to_collections(layer.kernel, variables_collections, 'weights')
    if layer.use_bias:
      _add_variable_to_collections(layer.bias, variables_collections, 'biases')

    if normalizer_fn is not None:
      normalizer_params = normalizer_params or {}
      outputs = normalizer_fn(outputs, **normalizer_params)

    if activation_fn is not None:
      outputs = activation_fn(outputs)
    return utils.collect_named_outputs(outputs_collections,
                                       sc.original_name_scope, outputs)


masked_conv2d = masked_convolution


@add_arg_scope
def masked_fully_connected(
    inputs,
    num_outputs,
    activation_fn=nn.relu,
    normalizer_fn=None,
    normalizer_params=None,
    weights_initializer=initializers.xavier_initializer(),
    weights_regularizer=None,
    biases_initializer=init_ops.zeros_initializer(),
    biases_regularizer=None,
    reuse=None,
    variables_collections=None,
    outputs_collections=None,
    trainable=True,
    scope=None):
274  """Adds a sparse fully connected layer. The weight matrix is masked.
275
276  `fully_connected` creates a variable called `weights`, representing a fully
277  connected weight matrix, which is multiplied by the `inputs` to produce a
278  `Tensor` of hidden units. If a `normalizer_fn` is provided (such as
279  `batch_norm`), it is then applied. Otherwise, if `normalizer_fn` is
280  None and a `biases_initializer` is provided then a `biases` variable would be
281  created and added the hidden units. Finally, if `activation_fn` is not `None`,
282  it is applied to the hidden units as well.
283
284  Note: that if `inputs` have a rank greater than 2, then `inputs` is flattened
285  prior to the initial matrix multiply by `weights`.

  Args:
    inputs: A tensor of at least rank 2 and static value for the last
      dimension; i.e. `[batch_size, depth]`, `[None, None, None, channels]`.
    num_outputs: Integer or long, the number of output units in the layer.
    activation_fn: Activation function. The default value is a ReLU function.
      Explicitly set it to None to skip it and maintain a linear activation.
    normalizer_fn: Normalization function to use instead of `biases`. If
      `normalizer_fn` is provided then `biases_initializer` and
      `biases_regularizer` are ignored and `biases` are not created nor added.
      Default is None, for no normalizer function.
    normalizer_params: Normalization function parameters.
    weights_initializer: An initializer for the weights.
    weights_regularizer: Optional regularizer for the weights.
    biases_initializer: An initializer for the biases. If None skip biases.
    biases_regularizer: Optional regularizer for the biases.
    reuse: Whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
    variables_collections: Optional list of collections for all the variables
      or a dictionary containing a different list of collections per variable.
    outputs_collections: Collection to add the outputs.
    trainable: If `True` also add variables to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    scope: Optional scope for variable_scope.

  Returns:
    The tensor variable representing the result of the series of operations.

  Raises:
    ValueError: If `inputs` has rank less than 2 or if its last dimension is
      not set.
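
  Example:
    A minimal usage sketch, assuming a rank-2 `[batch_size, depth]` `net`
    tensor (the name is illustrative, not part of the API):

      logits = masked_fully_connected(net, 10, activation_fn=None,
                                      scope='logits')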
316  """
  if not isinstance(num_outputs, six.integer_types):
    raise ValueError('num_outputs should be int or long, got %s.' %
                     (num_outputs,))

  layer_variable_getter = _build_variable_getter({
      'bias': 'biases',
      'kernel': 'weights'
  })

  with variable_scope.variable_scope(
      scope,
      'fully_connected', [inputs],
      reuse=reuse,
      custom_getter=layer_variable_getter) as sc:
    inputs = ops.convert_to_tensor(inputs)
    layer = core.MaskedFullyConnected(
        units=num_outputs,
        activation=None,
        use_bias=not normalizer_fn and biases_initializer,
        kernel_initializer=weights_initializer,
        bias_initializer=biases_initializer,
        kernel_regularizer=weights_regularizer,
        bias_regularizer=biases_regularizer,
        activity_regularizer=None,
        trainable=trainable,
        name=sc.name,
        dtype=inputs.dtype.base_dtype,
        _scope=sc,
        _reuse=reuse)
    outputs = layer.apply(inputs)

    # Add variables to collections.
    _add_variable_to_collections(layer.kernel, variables_collections, 'weights')
    if layer.bias is not None:
      _add_variable_to_collections(layer.bias, variables_collections, 'biases')

    # Apply normalizer function / layer.
    if normalizer_fn is not None:
      if not normalizer_params:
        normalizer_params = {}
      outputs = normalizer_fn(outputs, **normalizer_params)

    if activation_fn is not None:
      outputs = activation_fn(outputs)

    return utils.collect_named_outputs(outputs_collections,
                                       sc.original_name_scope, outputs)
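
# A minimal end-to-end sketch of how these layers are typically consumed
# (`images` and the layer sizes are illustrative placeholders):
#
#   net = masked_conv2d(images, num_outputs=32, kernel_size=3, scope='conv1')
#   net = masked_fully_connected(net, 10, scope='fc1')
#
# Each masked layer creates a `mask` variable alongside its `weights`; the
# model_pruning library updates these masks during training to zero out
# low-magnitude weights.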
364