1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Operators specific to data structures: list append, subscripts, etc.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22 23from tensorflow.python.framework import constant_op 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import tensor_util 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import control_flow_ops 29from tensorflow.python.ops import list_ops 30from tensorflow.python.ops import tensor_array_ops 31 32 33# TODO(mdan): Once control flow supports objects, repackage as a class. 34 35 36def new_list(iterable=None): 37 """The list constructor. 38 39 Args: 40 iterable: Optional elements to fill the list with. 41 42 Returns: 43 A list-like object. The exact return value depends on the initial elements. 44 """ 45 if iterable: 46 elements = tuple(iterable) 47 else: 48 elements = () 49 50 if elements: 51 # When the list contains elements, it is assumed to be a "Python" lvalue 52 # list. 53 return _py_list_new(elements) 54 return tf_tensor_list_new(elements) 55 56 57def tf_tensor_array_new(elements, element_dtype=None, element_shape=None): 58 """Overload of new_list that stages a Tensor list creation.""" 59 elements = tuple(ops.convert_to_tensor(el) for el in elements) 60 61 all_dtypes = set(el.dtype for el in elements) 62 if len(all_dtypes) == 1: 63 inferred_dtype, = tuple(all_dtypes) 64 if element_dtype is not None and element_dtype != inferred_dtype: 65 raise ValueError( 66 'incompatible dtype; specified: {}, inferred from {}: {}'.format( 67 element_dtype, elements, inferred_dtype)) 68 elif len(all_dtypes) > 1: 69 raise ValueError( 70 'TensorArray requires all elements to have the same dtype:' 71 ' {}'.format(elements)) 72 else: 73 if element_dtype is None: 74 raise ValueError('dtype is required to create an empty TensorArray') 75 76 all_shapes = set(tuple(el.shape.as_list()) for el in elements) 77 if len(all_shapes) == 1: 78 inferred_shape, = tuple(all_shapes) 79 if element_shape is not None and element_shape != inferred_shape: 80 raise ValueError( 81 'incompatible shape; specified: {}, inferred from {}: {}'.format( 82 element_shape, elements, inferred_shape)) 83 elif len(all_shapes) > 1: 84 raise ValueError( 85 'TensorArray requires all elements to have the same shape:' 86 ' {}'.format(elements)) 87 # TODO(mdan): We may want to allow different shapes with infer_shape=False. 88 else: 89 inferred_shape = None 90 91 if element_dtype is None: 92 element_dtype = inferred_dtype 93 if element_shape is None: 94 element_shape = inferred_shape 95 96 l = tensor_array_ops.TensorArray( 97 dtype=element_dtype, 98 size=len(elements), 99 dynamic_size=True, 100 infer_shape=(element_shape is None), 101 element_shape=element_shape) 102 for i, el in enumerate(elements): 103 l = l.write(i, el) 104 return l 105 106 107def tf_tensor_list_new(elements, element_dtype=None, element_shape=None): 108 """Overload of new_list that stages a Tensor list creation.""" 109 if tensor_util.is_tensor(elements): 110 if element_shape is not None: 111 raise ValueError( 112 'element shape may not be specified when creating list from tensor') 113 element_shape = array_ops.shape(elements)[1:] 114 l = list_ops.tensor_list_from_tensor(elements, element_shape=element_shape) 115 return l 116 117 elements = tuple(ops.convert_to_tensor(el) for el in elements) 118 119 all_dtypes = set(el.dtype for el in elements) 120 if len(all_dtypes) == 1: 121 inferred_dtype = tuple(all_dtypes)[0] 122 if element_dtype is not None and element_dtype != inferred_dtype: 123 raise ValueError( 124 'incompatible dtype; specified: {}, inferred from {}: {}'.format( 125 element_dtype, elements, inferred_dtype)) 126 elif all_dtypes: 127 # Heterogeneous lists are ok. 128 if element_dtype is not None: 129 raise ValueError( 130 'specified dtype {} is inconsistent with that of elements {}'.format( 131 element_dtype, elements)) 132 inferred_dtype = dtypes.variant 133 else: 134 inferred_dtype = dtypes.variant 135 136 all_shapes = set(tuple(el.shape.as_list()) for el in elements) 137 if len(all_shapes) == 1: 138 inferred_shape = array_ops.shape(elements[0]) 139 if element_shape is not None and element_shape != inferred_shape: 140 raise ValueError( 141 'incompatible shape; specified: {}, inferred from {}: {}'.format( 142 element_shape, elements, inferred_shape)) 143 elif all_shapes: 144 # Heterogeneous lists are ok. 145 if element_shape is not None: 146 raise ValueError( 147 'specified shape {} is inconsistent with that of elements {}'.format( 148 element_shape, elements)) 149 inferred_shape = constant_op.constant(-1) # unknown shape, by convention 150 else: 151 inferred_shape = constant_op.constant(-1) # unknown shape, by convention 152 153 if element_dtype is None: 154 element_dtype = inferred_dtype 155 if element_shape is None: 156 element_shape = inferred_shape 157 158 element_shape = ops.convert_to_tensor(element_shape, dtype=dtypes.int32) 159 l = list_ops.empty_tensor_list( 160 element_shape=element_shape, element_dtype=element_dtype) 161 for el in elements: 162 l = list_ops.tensor_list_push_back(l, el) 163 return l 164 165 166def _py_list_new(elements): 167 """Overload of new_list that creates a Python list.""" 168 return list(elements) 169 170 171def list_append(list_, x): 172 """The list append function. 173 174 Note: it is unspecified where list_ will be mutated or not. If list_ is 175 a TensorFlow entity, it will not be typically mutated. If list_ is a plain 176 list, it will be. In general, if the list is mutated then the return value 177 should point to the original entity. 178 179 Args: 180 list_: An entity that supports append semantics. 181 x: The element to append. 182 183 Returns: 184 Same as list_, after the append was performed. 185 186 Raises: 187 ValueError: if list_ is not of a known list-like type. 188 """ 189 if isinstance(list_, tensor_array_ops.TensorArray): 190 return _tf_tensorarray_append(list_, x) 191 elif tensor_util.is_tensor(list_): 192 if list_.dtype == dtypes.variant: 193 return _tf_tensor_list_append(list_, x) 194 else: 195 raise ValueError( 196 'tensor lists are expected to be Tensors with dtype=tf.variant,' 197 ' instead found %s' % list_) 198 else: 199 return _py_list_append(list_, x) 200 201 202def _tf_tensor_list_append(list_, x): 203 """Overload of list_append that stages a Tensor list write.""" 204 def empty_list_of_elements_like_x(): 205 tensor_x = ops.convert_to_tensor(x) 206 return list_ops.empty_tensor_list( 207 element_shape=array_ops.shape(tensor_x), 208 element_dtype=tensor_x.dtype) 209 210 list_ = control_flow_ops.cond( 211 list_ops.tensor_list_length(list_) > 0, 212 lambda: list_, 213 empty_list_of_elements_like_x, 214 ) 215 return list_ops.tensor_list_push_back(list_, x) 216 217 218def _tf_tensorarray_append(list_, x): 219 """Overload of list_append that stages a TensorArray write.""" 220 return list_.write(list_.size(), x) 221 222 223def _py_list_append(list_, x): 224 """Overload of list_append that executes a Python list append.""" 225 # Revert to the original call. 226 list_.append(x) 227 return list_ 228 229 230class ListPopOpts( 231 collections.namedtuple('ListPopOpts', ('element_dtype', 'element_shape'))): 232 pass 233 234 235def list_pop(list_, i, opts): 236 """The list pop function. 237 238 Note: it is unspecified where list_ will be mutated or not. If list_ is 239 a TensorFlow entity, it will not be typically mutated. If list_ is a plain 240 list, it will be. In general, if the list is mutated then the return value 241 should point to the original entity. 242 243 Args: 244 list_: An entity that supports pop semantics. 245 i: Optional index to pop from. May be None. 246 opts: A ListPopOpts. 247 248 Returns: 249 Tuple (x, out_list_): 250 out_list_: same as list_, after the removal was performed. 251 x: the removed element value. 252 253 Raises: 254 ValueError: if list_ is not of a known list-like type or the operation is 255 not supported for that type. 256 """ 257 assert isinstance(opts, ListPopOpts) 258 259 if isinstance(list_, tensor_array_ops.TensorArray): 260 raise ValueError('TensorArray does not support item removal') 261 elif tensor_util.is_tensor(list_): 262 if list_.dtype == dtypes.variant: 263 return _tf_tensor_list_pop(list_, i, opts) 264 else: 265 raise ValueError( 266 'tensor lists are expected to be Tensors with dtype=tf.variant,' 267 ' instead found %s' % list_) 268 else: 269 return _py_list_pop(list_, i) 270 271 272def _tf_tensor_list_pop(list_, i, opts): 273 """Overload of list_pop that stages a Tensor list pop.""" 274 if i is not None: 275 raise NotImplementedError('tensor lists only support removing from the end') 276 277 if opts.element_dtype is None: 278 raise ValueError('cannot pop from a list without knowing its element ' 279 'type; use set_element_type to annotate it') 280 if opts.element_shape is None: 281 raise ValueError('cannot pop from a list without knowing its element ' 282 'shape; use set_element_type to annotate it') 283 list_out, x = list_ops.tensor_list_pop_back( 284 list_, element_dtype=opts.element_dtype) 285 x.set_shape(opts.element_shape) 286 return list_out, x 287 288 289def _py_list_pop(list_, i): 290 """Overload of list_pop that executes a Python list append.""" 291 if i is None: 292 x = list_.pop() 293 else: 294 x = list_.pop(i) 295 return list_, x 296 297 298# TODO(mdan): Look into reducing duplication between all these containers. 299class ListStackOpts( 300 collections.namedtuple('ListStackOpts', 301 ('element_dtype', 'original_call'))): 302 pass 303 304 305def list_stack(list_, opts): 306 """The list stack function. 307 308 This does not have a direct correspondent in Python. The closest idiom to 309 this is tf.append or np.stack. It's different from those in the sense that it 310 accepts a Tensor list, rather than a list of tensors. It can also accept 311 TensorArray. When the target is anything else, the dispatcher will rely on 312 ctx.original_call for fallback. 313 314 Args: 315 list_: An entity that supports append semantics. 316 opts: A ListStackOpts object. 317 318 Returns: 319 The output of the stack operation, typically a Tensor. 320 """ 321 assert isinstance(opts, ListStackOpts) 322 323 if isinstance(list_, tensor_array_ops.TensorArray): 324 return _tf_tensorarray_stack(list_) 325 elif tensor_util.is_tensor(list_): 326 if list_.dtype == dtypes.variant: 327 return _tf_tensor_list_stack(list_, opts) 328 else: 329 # No-op for primitive Tensor arguments. 330 return list_ 331 else: 332 return _py_list_stack(list_, opts) 333 334 335def _tf_tensorarray_stack(list_): 336 """Overload of list_stack that stages a TensorArray stack.""" 337 return list_.stack() 338 339 340def _tf_tensor_list_stack(list_, opts): 341 """Overload of list_stack that stages a Tensor list write.""" 342 if opts.element_dtype is None: 343 raise ValueError('cannot stack a list without knowing its element type;' 344 ' use set_element_type to annotate it') 345 return list_ops.tensor_list_stack(list_, element_dtype=opts.element_dtype) 346 347 348def _py_list_stack(list_, opts): 349 """Overload of list_stack that executes a Python list append.""" 350 # Revert to the original call. 351 return opts.original_call(list_) 352