1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Contains the TFExampleDecoder its associated helper classes. 16 17The TFExampleDecode is a DataDecoder used to decode TensorFlow Example protos. 18In order to do so each requested item must be paired with one or more Example 19features that are parsed to produce the Tensor-based manifestation of the item. 20""" 21 22from __future__ import absolute_import 23from __future__ import division 24from __future__ import print_function 25 26import abc 27 28import six 29 30from tensorflow.contrib.slim.python.slim.data import data_decoder 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import sparse_tensor 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import control_flow_ops 35from tensorflow.python.ops import map_fn 36from tensorflow.python.ops import image_ops 37from tensorflow.python.ops import math_ops 38from tensorflow.python.ops import parsing_ops 39from tensorflow.python.ops import sparse_ops 40 41 42@six.add_metaclass(abc.ABCMeta) 43class ItemHandler(object): 44 """Specifies the item-to-Features mapping for tf.parse_example. 45 46 An ItemHandler both specifies a list of Features used for parsing an Example 47 proto as well as a function that post-processes the results of Example 48 parsing. 49 """ 50 51 def __init__(self, keys): 52 """Constructs the handler with the name of the tf.Feature keys to use. 53 54 See third_party/tensorflow/core/example/feature.proto 55 56 Args: 57 keys: the name of the TensorFlow Example Feature. 58 """ 59 if not isinstance(keys, (tuple, list)): 60 keys = [keys] 61 self._keys = keys 62 63 @property 64 def keys(self): 65 return self._keys 66 67 @abc.abstractmethod 68 def tensors_to_item(self, keys_to_tensors): 69 """Maps the given dictionary of tensors to the requested item. 70 71 Args: 72 keys_to_tensors: a mapping of TF-Example keys to parsed tensors. 73 74 Returns: 75 the final tensor representing the item being handled. 76 """ 77 pass 78 79 80class ItemHandlerCallback(ItemHandler): 81 """An ItemHandler that converts the parsed tensors via a given function. 82 83 Unlike other ItemHandlers, the ItemHandlerCallback resolves its item via 84 a callback function rather than using prespecified behavior. 85 """ 86 87 def __init__(self, keys, func): 88 """Initializes the ItemHandler. 89 90 Args: 91 keys: a list of TF-Example keys. 92 func: a function that takes as an argument a dictionary from `keys` to 93 parsed Tensors. 94 """ 95 super(ItemHandlerCallback, self).__init__(keys) 96 self._func = func 97 98 def tensors_to_item(self, keys_to_tensors): 99 return self._func(keys_to_tensors) 100 101 102class BoundingBox(ItemHandler): 103 """An ItemHandler that concatenates a set of parsed Tensors to Bounding Boxes. 104 """ 105 106 def __init__(self, keys=None, prefix=''): 107 """Initialize the bounding box handler. 108 109 Args: 110 keys: A list of four key names representing the ymin, xmin, ymax, mmax 111 prefix: An optional prefix for each of the bounding box keys. 112 If provided, `prefix` is appended to each key in `keys`. 113 114 Raises: 115 ValueError: if keys is not `None` and also not a list of exactly 4 keys 116 """ 117 if keys is None: 118 keys = ['ymin', 'xmin', 'ymax', 'xmax'] 119 elif len(keys) != 4: 120 raise ValueError('BoundingBox expects 4 keys but got {}'.format( 121 len(keys))) 122 self._prefix = prefix 123 self._keys = keys 124 self._full_keys = [prefix + k for k in keys] 125 super(BoundingBox, self).__init__(self._full_keys) 126 127 def tensors_to_item(self, keys_to_tensors): 128 """Maps the given dictionary of tensors to a concatenated list of bboxes. 129 130 Args: 131 keys_to_tensors: a mapping of TF-Example keys to parsed tensors. 132 133 Returns: 134 [num_boxes, 4] tensor of bounding box coordinates, 135 i.e. 1 bounding box per row, in order [y_min, x_min, y_max, x_max]. 136 """ 137 sides = [] 138 for key in self._full_keys: 139 side = keys_to_tensors[key] 140 if isinstance(side, sparse_tensor.SparseTensor): 141 side = side.values 142 side = array_ops.expand_dims(side, 0) 143 sides.append(side) 144 145 bounding_box = array_ops.concat(sides, 0) 146 return array_ops.transpose(bounding_box) 147 148 149class Tensor(ItemHandler): 150 """An ItemHandler that returns a parsed Tensor.""" 151 152 def __init__(self, tensor_key, shape_keys=None, shape=None, default_value=0): 153 """Initializes the Tensor handler. 154 155 Tensors are, by default, returned without any reshaping. However, there are 156 two mechanisms which allow reshaping to occur at load time. If `shape_keys` 157 is provided, both the `Tensor` corresponding to `tensor_key` and 158 `shape_keys` is loaded and the former `Tensor` is reshaped with the values 159 of the latter. Alternatively, if a fixed `shape` is provided, the `Tensor` 160 corresponding to `tensor_key` is loaded and reshape appropriately. 161 If neither `shape_keys` nor `shape` are provided, the `Tensor` will be 162 returned without any reshaping. 163 164 Args: 165 tensor_key: the name of the `TFExample` feature to read the tensor from. 166 shape_keys: Optional name or list of names of the TF-Example feature in 167 which the tensor shape is stored. If a list, then each corresponds to 168 one dimension of the shape. 169 shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is 170 reshaped accordingly. 171 default_value: The value used when the `tensor_key` is not found in a 172 particular `TFExample`. 173 174 Raises: 175 ValueError: if both `shape_keys` and `shape` are specified. 176 """ 177 if shape_keys and shape is not None: 178 raise ValueError('Cannot specify both shape_keys and shape parameters.') 179 if shape_keys and not isinstance(shape_keys, list): 180 shape_keys = [shape_keys] 181 self._tensor_key = tensor_key 182 self._shape_keys = shape_keys 183 self._shape = shape 184 self._default_value = default_value 185 keys = [tensor_key] 186 if shape_keys: 187 keys.extend(shape_keys) 188 super(Tensor, self).__init__(keys) 189 190 def tensors_to_item(self, keys_to_tensors): 191 tensor = keys_to_tensors[self._tensor_key] 192 shape = self._shape 193 if self._shape_keys: 194 shape_dims = [] 195 for k in self._shape_keys: 196 shape_dim = keys_to_tensors[k] 197 if isinstance(shape_dim, sparse_tensor.SparseTensor): 198 shape_dim = sparse_ops.sparse_tensor_to_dense(shape_dim) 199 shape_dims.append(shape_dim) 200 shape = array_ops.reshape(array_ops.stack(shape_dims), [-1]) 201 if isinstance(tensor, sparse_tensor.SparseTensor): 202 if shape is not None: 203 tensor = sparse_ops.sparse_reshape(tensor, shape) 204 tensor = sparse_ops.sparse_tensor_to_dense(tensor, self._default_value) 205 else: 206 if shape is not None: 207 tensor = array_ops.reshape(tensor, shape) 208 return tensor 209 210 211class LookupTensor(Tensor): 212 """An ItemHandler that returns a parsed Tensor, the result of a lookup.""" 213 214 def __init__(self, 215 tensor_key, 216 table, 217 shape_keys=None, 218 shape=None, 219 default_value=''): 220 """Initializes the LookupTensor handler. 221 222 See Tensor. Simply calls a vocabulary (most often, a label mapping) lookup. 223 224 Args: 225 tensor_key: the name of the `TFExample` feature to read the tensor from. 226 table: A tf.lookup table. 227 shape_keys: Optional name or list of names of the TF-Example feature in 228 which the tensor shape is stored. If a list, then each corresponds to 229 one dimension of the shape. 230 shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is 231 reshaped accordingly. 232 default_value: The value used when the `tensor_key` is not found in a 233 particular `TFExample`. 234 235 Raises: 236 ValueError: if both `shape_keys` and `shape` are specified. 237 """ 238 self._table = table 239 super(LookupTensor, self).__init__(tensor_key, shape_keys, shape, 240 default_value) 241 242 def tensors_to_item(self, keys_to_tensors): 243 unmapped_tensor = super(LookupTensor, self).tensors_to_item(keys_to_tensors) 244 return self._table.lookup(unmapped_tensor) 245 246 247class BackupHandler(ItemHandler): 248 """An ItemHandler that tries two ItemHandlers in order.""" 249 250 def __init__(self, handler, backup): 251 """Initializes the BackupHandler handler. 252 253 If the first Handler's tensors_to_item returns a Tensor with no elements, 254 the second Handler is used. 255 256 Args: 257 handler: The primary ItemHandler. 258 backup: The backup ItemHandler. 259 260 Raises: 261 ValueError: if either is not an ItemHandler. 262 """ 263 if not isinstance(handler, ItemHandler): 264 raise ValueError('Primary handler is of type %s instead of ItemHandler' 265 % type(handler)) 266 if not isinstance(backup, ItemHandler): 267 raise ValueError('Backup handler is of type %s instead of ItemHandler' 268 % type(backup)) 269 self._handler = handler 270 self._backup = backup 271 super(BackupHandler, self).__init__(handler.keys + backup.keys) 272 273 def tensors_to_item(self, keys_to_tensors): 274 item = self._handler.tensors_to_item(keys_to_tensors) 275 return control_flow_ops.cond( 276 pred=math_ops.equal(math_ops.reduce_prod(array_ops.shape(item)), 0), 277 true_fn=lambda: self._backup.tensors_to_item(keys_to_tensors), 278 false_fn=lambda: item) 279 280 281class SparseTensor(ItemHandler): 282 """An ItemHandler for SparseTensors.""" 283 284 def __init__(self, 285 indices_key=None, 286 values_key=None, 287 shape_key=None, 288 shape=None, 289 densify=False, 290 default_value=0): 291 """Initializes the Tensor handler. 292 293 Args: 294 indices_key: the name of the TF-Example feature that contains the ids. 295 Defaults to 'indices'. 296 values_key: the name of the TF-Example feature that contains the values. 297 Defaults to 'values'. 298 shape_key: the name of the TF-Example feature that contains the shape. 299 If provided it would be used. 300 shape: the output shape of the SparseTensor. If `shape_key` is not 301 provided this `shape` would be used. 302 densify: whether to convert the SparseTensor into a dense Tensor. 303 default_value: Scalar value to set when making dense for indices not 304 specified in the `SparseTensor`. 305 """ 306 indices_key = indices_key or 'indices' 307 values_key = values_key or 'values' 308 self._indices_key = indices_key 309 self._values_key = values_key 310 self._shape_key = shape_key 311 self._shape = shape 312 self._densify = densify 313 self._default_value = default_value 314 keys = [indices_key, values_key] 315 if shape_key: 316 keys.append(shape_key) 317 super(SparseTensor, self).__init__(keys) 318 319 def tensors_to_item(self, keys_to_tensors): 320 indices = keys_to_tensors[self._indices_key] 321 values = keys_to_tensors[self._values_key] 322 if self._shape_key: 323 shape = keys_to_tensors[self._shape_key] 324 if isinstance(shape, sparse_tensor.SparseTensor): 325 shape = sparse_ops.sparse_tensor_to_dense(shape) 326 elif self._shape: 327 shape = self._shape 328 else: 329 shape = indices.dense_shape 330 indices_shape = array_ops.shape(indices.indices) 331 rank = indices_shape[1] 332 ids = math_ops.cast(indices.values, dtypes.int64) 333 indices_columns_to_preserve = array_ops.slice( 334 indices.indices, [0, 0], array_ops.stack([-1, rank - 1])) 335 new_indices = array_ops.concat( 336 [indices_columns_to_preserve, array_ops.reshape(ids, [-1, 1])], 1) 337 338 tensor = sparse_tensor.SparseTensor(new_indices, values.values, shape) 339 if self._densify: 340 tensor = sparse_ops.sparse_tensor_to_dense(tensor, self._default_value) 341 return tensor 342 343 344class Image(ItemHandler): 345 """An ItemHandler that decodes a parsed Tensor as an image.""" 346 347 def __init__(self, 348 image_key=None, 349 format_key=None, 350 shape=None, 351 channels=3, 352 dtype=dtypes.uint8, 353 repeated=False, 354 dct_method=''): 355 """Initializes the image. 356 357 Args: 358 image_key: the name of the TF-Example feature in which the encoded image 359 is stored. 360 format_key: the name of the TF-Example feature in which the image format 361 is stored. 362 shape: the output shape of the image as 1-D `Tensor` 363 [height, width, channels]. If provided, the image is reshaped 364 accordingly. If left as None, no reshaping is done. A shape should 365 be supplied only if all the stored images have the same shape. 366 channels: the number of channels in the image. 367 dtype: images will be decoded at this bit depth. Different formats 368 support different bit depths. 369 See tf.image.decode_image, 370 tf.decode_raw, 371 repeated: if False, decodes a single image. If True, decodes a 372 variable number of image strings from a 1D tensor of strings. 373 dct_method: An optional string. Defaults to empty string. It only takes 374 effect when image format is jpeg, used to specify a hint about the 375 algorithm used for jpeg decompression. Currently valid values 376 are ['INTEGER_FAST', 'INTEGER_ACCURATE']. The hint may be ignored, for 377 example, the jpeg library does not have that specific option. 378 """ 379 if not image_key: 380 image_key = 'image/encoded' 381 if not format_key: 382 format_key = 'image/format' 383 384 super(Image, self).__init__([image_key, format_key]) 385 self._image_key = image_key 386 self._format_key = format_key 387 self._shape = shape 388 self._channels = channels 389 self._dtype = dtype 390 self._repeated = repeated 391 self._dct_method = dct_method 392 393 def tensors_to_item(self, keys_to_tensors): 394 """See base class.""" 395 image_buffer = keys_to_tensors[self._image_key] 396 image_format = keys_to_tensors[self._format_key] 397 398 if self._repeated: 399 return map_fn.map_fn(lambda x: self._decode(x, image_format), 400 image_buffer, dtype=self._dtype) 401 else: 402 return self._decode(image_buffer, image_format) 403 404 def _decode(self, image_buffer, image_format): 405 """Decodes the image buffer. 406 407 Args: 408 image_buffer: The tensor representing the encoded image tensor. 409 image_format: The image format for the image in `image_buffer`. If image 410 format is `raw`, all images are expected to be in this format, otherwise 411 this op can decode a mix of `jpg` and `png` formats. 412 413 Returns: 414 A tensor that represents decoded image of self._shape, or 415 (?, ?, self._channels) if self._shape is not specified. 416 """ 417 418 def decode_image(): 419 """Decodes a image based on the headers.""" 420 return math_ops.cast( 421 image_ops.decode_image(image_buffer, channels=self._channels), 422 self._dtype) 423 424 def decode_jpeg(): 425 """Decodes a jpeg image with specified '_dct_method'.""" 426 return math_ops.cast( 427 image_ops.decode_jpeg( 428 image_buffer, 429 channels=self._channels, 430 dct_method=self._dct_method), self._dtype) 431 432 def check_jpeg(): 433 """Checks if an image is jpeg.""" 434 # For jpeg, we directly use image_ops.decode_jpeg rather than decode_image 435 # in order to feed the jpeg specify parameter 'dct_method'. 436 return control_flow_ops.cond( 437 image_ops.is_jpeg(image_buffer), 438 decode_jpeg, 439 decode_image, 440 name='cond_jpeg') 441 442 def decode_raw(): 443 """Decodes a raw image.""" 444 return parsing_ops.decode_raw(image_buffer, out_type=self._dtype) 445 446 pred_fn_pairs = { 447 math_ops.logical_or( 448 math_ops.equal(image_format, 'raw'), 449 math_ops.equal(image_format, 'RAW')): decode_raw, 450 } 451 image = control_flow_ops.case( 452 pred_fn_pairs, default=check_jpeg, exclusive=True) 453 454 image.set_shape([None, None, self._channels]) 455 if self._shape is not None: 456 image = array_ops.reshape(image, self._shape) 457 458 return image 459 460 461class TFExampleDecoder(data_decoder.DataDecoder): 462 """A decoder for TensorFlow Examples. 463 464 Decoding Example proto buffers is comprised of two stages: (1) Example parsing 465 and (2) tensor manipulation. 466 467 In the first stage, the tf.parse_example function is called with a list of 468 FixedLenFeatures and SparseLenFeatures. These instances tell TF how to parse 469 the example. The output of this stage is a set of tensors. 470 471 In the second stage, the resulting tensors are manipulated to provide the 472 requested 'item' tensors. 473 474 To perform this decoding operation, an ExampleDecoder is given a list of 475 ItemHandlers. Each ItemHandler indicates the set of features for stage 1 and 476 contains the instructions for post_processing its tensors for stage 2. 477 """ 478 479 def __init__(self, keys_to_features, items_to_handlers): 480 """Constructs the decoder. 481 482 Args: 483 keys_to_features: a dictionary from TF-Example keys to either 484 tf.VarLenFeature or tf.FixedLenFeature instances. See tensorflow's 485 parsing_ops.py. 486 items_to_handlers: a dictionary from items (strings) to ItemHandler 487 instances. Note that the ItemHandler's are provided the keys that they 488 use to return the final item Tensors. 489 """ 490 self._keys_to_features = keys_to_features 491 self._items_to_handlers = items_to_handlers 492 493 def list_items(self): 494 """See base class.""" 495 return list(self._items_to_handlers.keys()) 496 497 def decode(self, serialized_example, items=None): 498 """Decodes the given serialized TF-example. 499 500 Args: 501 serialized_example: a serialized TF-example tensor. 502 items: the list of items to decode. These must be a subset of the item 503 keys in self._items_to_handlers. If `items` is left as None, then all 504 of the items in self._items_to_handlers are decoded. 505 506 Returns: 507 the decoded items, a list of tensor. 508 """ 509 example = parsing_ops.parse_single_example(serialized_example, 510 self._keys_to_features) 511 512 # Reshape non-sparse elements just once, adding the reshape ops in 513 # deterministic order. 514 for k in sorted(self._keys_to_features): 515 v = self._keys_to_features[k] 516 if isinstance(v, parsing_ops.FixedLenFeature): 517 example[k] = array_ops.reshape(example[k], v.shape) 518 519 if not items: 520 items = self._items_to_handlers.keys() 521 522 outputs = [] 523 for item in items: 524 handler = self._items_to_handlers[item] 525 keys_to_tensors = {key: example[key] for key in handler.keys} 526 outputs.append(handler.tensors_to_item(keys_to_tensors)) 527 return outputs 528