# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common util functions used by layers.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import namedtuple
from collections import OrderedDict
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables

__all__ = ['collect_named_outputs',
           'constant_value',
           'static_cond',
           'smart_cond',
           'get_variable_collections',
           'two_element_tuple',
           'n_positive_integers',
           'channel_dimension',
           'last_dimension']

NamedOutputs = namedtuple('NamedOutputs', ['name', 'outputs'])


def collect_named_outputs(collections, alias, outputs):
  """Add `Tensor` outputs tagged with alias to collections.

  It is useful to collect end-points or tags for summaries. Example of usage:

    logits = collect_named_outputs('end_points', 'inception_v3/logits', logits)
    assert 'inception_v3/logits' in logits.aliases

  Args:
    collections: A collection or list of collections. If None (or otherwise
      falsy, e.g. an empty list), nothing is collected and no alias is added.
    alias: String to append to the list of aliases of outputs, for example,
      'inception_v3/conv1'.
    outputs: Tensor, an output tensor to collect.

  Returns:
    The outputs Tensor to allow inline call.
  """
  if collections:
    append_tensor_alias(outputs, alias)
    ops.add_to_collections(collections, outputs)
  return outputs


def append_tensor_alias(tensor, alias):
  """Append an alias to the list of aliases of the tensor.

  A single trailing '/' is stripped from `alias` before it is stored. The
  `aliases` attribute is created on the tensor if it does not exist yet.

  Args:
    tensor: A `Tensor`.
    alias: String, to add to the list of aliases of the tensor.

  Returns:
    The tensor with a new alias appended to its list of aliases.
  """
  # Remove ending '/' if present. `endswith` is used instead of `alias[-1]`
  # so that an empty alias does not raise IndexError.
  if alias.endswith('/'):
    alias = alias[:-1]
  if hasattr(tensor, 'aliases'):
    tensor.aliases.append(alias)
  else:
    tensor.aliases = [alias]
  return tensor


def gather_tensors_aliases(tensors):
  """Given a list of tensors, gather their aliases.

  Args:
    tensors: A list of `Tensors`.

  Returns:
    A new list of strings with the aliases of all tensors, in input order.
  """
  aliases = []
  for tensor in tensors:
    aliases.extend(get_tensor_aliases(tensor))
  return aliases


def get_tensor_aliases(tensor):
  """Get a list with the aliases of the input tensor.

  If the tensor does not have any alias, it defaults to its op.name (when the
  tensor name ends in ':0') or to its full tensor name otherwise.

  Args:
    tensor: A `Tensor`.

  Returns:
    A list of strings with the aliases of the tensor. When the tensor has an
    `aliases` attribute, that list object itself is returned (not a copy).
  """
  if hasattr(tensor, 'aliases'):
    aliases = tensor.aliases
  elif tensor.name.endswith(':0'):
    # Use op.name for tensor ending in :0
    aliases = [tensor.op.name]
  else:
    aliases = [tensor.name]
  return aliases


def convert_collection_to_dict(collection, clear_collection=False):
  """Returns an OrderedDict of Tensors with their aliases as keys.

  A tensor with several aliases appears once under each alias. If two
  tensors share an alias, the later tensor in the collection wins.

  Args:
    collection: A collection.
    clear_collection: When True, it clears the collection after converting to
      OrderedDict.

  Returns:
    An OrderedDict of {alias: tensor}
  """
  output = OrderedDict((alias, tensor)
                       for tensor in ops.get_collection(collection)
                       for alias in get_tensor_aliases(tensor))
  if clear_collection:
    ops.get_default_graph().clear_collection(collection)
  return output


def constant_value(value_or_tensor_or_var, dtype=None):
  """Returns value if value_or_tensor_or_var has a constant value.

  Plain Python values are returned unchanged. For a `Tensor`, the value
  statically inferred by `tensor_util.constant_value` is returned (None if it
  cannot be inferred). A `Variable` is never treated as constant.

  Args:
    value_or_tensor_or_var: A value, a `Tensor` or a `Variable`.
    dtype: Optional `tf.dtype`, if set it would check it has the right
      dtype (only checked for `Tensor`s and `Variable`s).

  Returns:
    The constant value or None if it is not constant.

  Raises:
    ValueError: if value_or_tensor_or_var is None or the tensor_variable has the
      wrong dtype.
  """
  if value_or_tensor_or_var is None:
    raise ValueError('value_or_tensor_or_var cannot be None')
  value = value_or_tensor_or_var
  if isinstance(value_or_tensor_or_var, (ops.Tensor, variables.Variable)):
    if dtype and value_or_tensor_or_var.dtype != dtype:
      raise ValueError('It has the wrong type %s instead of %s' % (
          value_or_tensor_or_var.dtype, dtype))
    if isinstance(value_or_tensor_or_var, variables.Variable):
      value = None
    else:
      value = tensor_util.constant_value(value_or_tensor_or_var)
  return value


def static_cond(pred, fn1, fn2):
  """Return either fn1() or fn2() based on the boolean value of `pred`.

  Same signature as `control_flow_ops.cond()` but requires pred to be a bool.

  Args:
    pred: A value determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.

  Returns:
    Tensors returned by the call to either `fn1` or `fn2`.

  Raises:
    TypeError: if `fn1` or `fn2` is not callable.
  """
  if not callable(fn1):
    raise TypeError('fn1 must be callable.')
  if not callable(fn2):
    raise TypeError('fn2 must be callable.')
  if pred:
    return fn1()
  else:
    return fn2()


def smart_cond(pred, fn1, fn2, name=None):
  """Return either fn1() or fn2() based on the boolean predicate/value `pred`.

  If `pred` is bool or has a constant value it would use `static_cond`,
  otherwise it would use `tf.cond`.

  Args:
    pred: A scalar determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.
    name: Optional name prefix when using tf.cond

  Returns:
    Tensors returned by the call to either `fn1` or `fn2`.

  Raises:
    ValueError: if `pred` is None.
    TypeError: if `pred` is constant and `fn1` or `fn2` is not callable.
  """
  pred_value = constant_value(pred)
  if pred_value is not None:
    # Use static_cond if pred has a constant value.
    return static_cond(pred_value, fn1, fn2)
  else:
    # Use dynamic cond otherwise. `name` is passed by keyword: in later
    # `cond` signatures the 4th positional parameter is `strict`, not `name`.
    return control_flow_ops.cond(pred, fn1, fn2, name=name)


def get_variable_collections(variables_collections, name):
  """Returns the variable collections to use for the variable `name`.

  Args:
    variables_collections: Either a dict keyed by variable name, or any other
      value (presumably a collection name, a list of them, or None) that is
      shared by all variables.
    name: The key to look up when `variables_collections` is a dict.

  Returns:
    `variables_collections.get(name)` (None when `name` is missing) if
    `variables_collections` is a dict; otherwise `variables_collections`
    unchanged.
  """
  if isinstance(variables_collections, dict):
    variable_collections = variables_collections.get(name, None)
  else:
    variable_collections = variables_collections
  return variable_collections


def _get_dimension(shape, dim, min_rank=1):
  """Returns the `dim` dimension of `shape`, while checking it has `min_rank`.

  Args:
    shape: A `TensorShape`.
    dim: Integer, which dimension to return (negative values count from the
      end).
    min_rank: Integer, minimum rank of shape.

  Returns:
    The value of the `dim` dimension.

  Raises:
    ValueError: if the rank of `shape` is unknown, if it has fewer than
      min_rank dimensions, if `dim` is out of range for its rank, or if the
      `dim` dimension value is not defined.
  """
  dims = shape.dims
  if dims is None:
    raise ValueError('dims of shape must be known but is None')
  if len(dims) < min_rank:
    raise ValueError('rank of shape must be at least %d not: %d' % (min_rank,
                                                                   len(dims)))
  # Report an out-of-range index as a ValueError (like the other checks)
  # instead of letting a raw IndexError escape from `dims[dim]`.
  if not -len(dims) <= dim < len(dims):
    raise ValueError('dimension %d is out of range for shape of rank %d: %s' %
                     (dim, len(dims), shape))
  value = dims[dim].value
  if value is None:
    raise ValueError(
        'dimension %d of shape must be known but is None: %s' % (dim, shape))
  return value


def channel_dimension(shape, data_format, min_rank=1):
  """Returns the channel dimension of shape, while checking it has min_rank.

  Args:
    shape: A `TensorShape`.
    data_format: `channels_first` or `channels_last`. Any value other than
      `channels_first` is treated as `channels_last`.
    min_rank: Integer, minimum rank of shape.

  Returns:
    The value of the channel dimension: dimension 1 for `channels_first`,
    the last dimension otherwise.

  Raises:
    ValueError: if inputs don't have at least min_rank dimensions, or if the
      channel dimension value is not defined.
  """
  return _get_dimension(shape, 1 if data_format == 'channels_first' else -1,
                        min_rank=min_rank)


def last_dimension(shape, min_rank=1):
  """Returns the last dimension of shape while checking it has min_rank.

  Args:
    shape: A `TensorShape`.
    min_rank: Integer, minimum rank of shape.

  Returns:
    The value of the last dimension.

  Raises:
    ValueError: if inputs don't have at least min_rank dimensions, or if the
      last dimension value is not defined.
  """
  return _get_dimension(shape, -1, min_rank=min_rank)


def two_element_tuple(int_or_tuple):
  """Converts `int_or_tuple` to height, width.

  Several of the functions that follow accept arguments as either
  a tuple of 2 integers or a single integer. A single integer
  indicates that the 2 values of the tuple are the same.

  This function normalizes the input value by always returning a tuple.

  Args:
    int_or_tuple: A list of 2 ints, a single int or a `TensorShape`.

  Returns:
    A tuple with 2 values. For lists/tuples and ints the values are converted
    with `int()`; for a `TensorShape` its two entries are returned as-is.

  Raises:
    ValueError: If `int_or_tuple` is not well formed.
  """
  if isinstance(int_or_tuple, (list, tuple)):
    if len(int_or_tuple) != 2:
      raise ValueError('Must be a list with 2 elements: %s' % int_or_tuple)
    return int(int_or_tuple[0]), int(int_or_tuple[1])
  if isinstance(int_or_tuple, int):
    return int(int_or_tuple), int(int_or_tuple)
  if isinstance(int_or_tuple, tensor_shape.TensorShape):
    if len(int_or_tuple) == 2:
      return int_or_tuple[0], int_or_tuple[1]
  raise ValueError('Must be an int, a list with 2 elements or a TensorShape of '
                   'length 2')


def n_positive_integers(n, value):
  """Converts `value` to a sequence of `n` positive integers.

  `value` may be either a sequence of values convertible to `int`, or a
  single value convertible to `int`, in which case the resulting integer is
  duplicated `n` times. It may also be a TensorShape of rank `n`.

  Args:
    n: Length of sequence to return.
    value: Either a single value convertible to a positive `int` or an
      `n`-element sequence of values convertible to a positive `int`.

  Returns:
    A tuple of `n` positive integers.

  Raises:
    TypeError: If `n` is not convertible to an integer, or if `value` is
      neither convertible to an integer nor a sized sequence.
    ValueError: If `n` or `value` are invalid.
  """

  n_orig = n
  n = int(n)
  # `n != n_orig` rejects values such as 2.5 that `int()` would truncate.
  if n < 1 or n != n_orig:
    raise ValueError('n must be a positive integer')

  try:
    value = int(value)
  except (TypeError, ValueError):
    # Not a scalar: treat `value` as a sequence that must have exactly n items.
    sequence_len = len(value)
    if sequence_len != n:
      raise ValueError(
          'Expected sequence of %d positive integers, but received %r' %
          (n, value))
    try:
      values = tuple(int(x) for x in value)
    except (TypeError, ValueError, OverflowError):
      # Narrowed from a bare `except:` so that KeyboardInterrupt/SystemExit
      # are not converted into a ValueError; these are the errors `int()`
      # raises for unconvertible input.
      raise ValueError(
          'Expected sequence of %d positive integers, but received %r' %
          (n, value))
    for x in values:
      if x < 1:
        raise ValueError('expected positive integer, but received %d' % x)
    return values

  if value < 1:
    raise ValueError('expected positive integer, but received %d' % value)
  return (value,) * n