# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to manipulate lists of tensors."""

# pylint: disable=g-bad-name
import numpy as np

from tensorflow.core.framework import full_type_pb2
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import handle_data_util
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_list_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util.lazy_loader import LazyLoader

# Lazy-loaded to break the import cycle:
# list_ops -> control_flow_ops -> tensor_array_ops -> list_ops
control_flow_ops = LazyLoader(
    "control_flow_ops", globals(),
    "tensorflow.python.ops.control_flow_ops")


ops.NotDifferentiable("TensorListConcatLists")
ops.NotDifferentiable("TensorListElementShape")
ops.NotDifferentiable("TensorListLength")
ops.NotDifferentiable("TensorListPushBackBatch")


def empty_tensor_list(element_shape,
                      element_dtype,
                      max_num_elements=None,
                      name=None):
  """Creates an empty TensorList; -1 encodes an unbounded max size."""
  if max_num_elements is None:
    max_num_elements = -1

  return gen_list_ops.empty_tensor_list(
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      max_num_elements=max_num_elements,
      name=name)


def _set_handle_data(list_handle, element_shape, element_dtype):
  """Sets type information on `list_handle` for consistency with graphs."""
  # TODO(b/169968286): It would be better if we had a consistent story for
  # creating handle data from eager operations (shared with VarHandleOp).
  if isinstance(list_handle, ops.EagerTensor):
    if tensor_util.is_tf_type(element_shape):
      element_shape = tensor_shape.TensorShape(None)
    elif not isinstance(element_shape, tensor_shape.TensorShape):
      element_shape = tensor_shape.TensorShape(element_shape)
    handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
    handle_data.is_set = True
    # TODO(b/191472076): This duplicates type inference. Clean up.
    handle_data.shape_and_type.append(
        cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
            shape=element_shape.as_proto(),
            dtype=element_dtype.as_datatype_enum,
            type=full_type_pb2.FullTypeDef(type_id=full_type_pb2.TFT_ARRAY)))
    list_handle._handle_data = handle_data  # pylint: disable=protected-access
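

# A minimal usage sketch for the constructors above. Illustrative only: the
# `_sketch_*` helpers in this file are hypothetical, are not part of the
# public API, and are never called at import time. This one assumes eager
# execution and uses `tensor_list_push_back` from the wildcard-imported
# gen_list_ops.
def _sketch_push_pop():
  """Creates an empty list of scalars, pushes one element, pops it back."""
  l = empty_tensor_list(element_shape=[], element_dtype=dtypes.float32)
  l = tensor_list_push_back(l, 1.0)  # The list now holds a single scalar.
  l, elem = tensor_list_pop_back(l, element_dtype=dtypes.float32)
  return elem  # A scalar float32 Tensor equal to 1.0; `l` is empty again.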


def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
  """Creates a TensorList with `num_elements` uninitialized slots."""
  result = gen_list_ops.tensor_list_reserve(
      element_shape=_build_element_shape(element_shape),
      num_elements=num_elements,
      element_dtype=element_dtype,
      name=name)
  # TODO(b/169968286): gen_ops needs to ensure the metadata is properly
  # populated for eager operations.
  _set_handle_data(result, element_shape, element_dtype)
  return result


def tensor_list_from_tensor(tensor, element_shape, name=None):
  """Creates a TensorList whose elements are the rows of `tensor`."""
  tensor = ops.convert_to_tensor(tensor)
  result = gen_list_ops.tensor_list_from_tensor(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      name=name)
  _set_handle_data(result, tensor.shape, tensor.dtype)
  return result


def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None,
                         name=None):
  return gen_list_ops.tensor_list_get_item(
      input_handle=input_handle,
      index=index,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      name=name)


def tensor_list_pop_back(input_handle, element_dtype, name=None):
  return gen_list_ops.tensor_list_pop_back(
      input_handle=input_handle,
      element_shape=-1,
      element_dtype=element_dtype,
      name=name)


def tensor_list_gather(input_handle,
                       indices,
                       element_dtype,
                       element_shape=None,
                       name=None):
  return gen_list_ops.tensor_list_gather(
      input_handle=input_handle,
      indices=indices,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      name=name)


def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
  """Returns a TensorList created or updated by scattering `tensor`."""
  tensor = ops.convert_to_tensor(tensor)
  if input_handle is not None:
    output_handle = gen_list_ops.tensor_list_scatter_into_existing_list(
        input_handle=input_handle, tensor=tensor, indices=indices, name=name)
    handle_data_util.copy_handle_data(input_handle, output_handle)
    return output_handle
  else:
    output_handle = gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=_build_element_shape(element_shape),
        num_elements=-1,
        name=name)
    _set_handle_data(output_handle, element_shape, tensor.dtype)
    return output_handle


def tensor_list_stack(input_handle,
                      element_dtype,
                      num_elements=-1,
                      element_shape=None,
                      name=None):
  return gen_list_ops.tensor_list_stack(
      input_handle=input_handle,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      num_elements=num_elements,
      name=name)


def tensor_list_concat(input_handle, element_dtype, element_shape=None,
                       name=None):
  # Ignore the lengths output of TensorListConcat. It is only used during
  # gradient computation.
  return gen_list_ops.tensor_list_concat_v2(
      input_handle=input_handle,
      element_dtype=element_dtype,
      element_shape=_build_element_shape(element_shape),
      leading_dims=ops.convert_to_tensor([], dtype=dtypes.int64),
      name=name)[0]
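

# A hedged sketch of the unstack/stack round trip using the wrappers above:
# `tensor_list_from_tensor` splits a tensor along axis 0 into list elements
# and `tensor_list_stack` reassembles them. `_sketch_round_trip` is a
# hypothetical helper, not public API.
def _sketch_round_trip(t):
  """Unstacks `t` into a TensorList and stacks it back; values unchanged."""
  t = ops.convert_to_tensor(t)
  l = tensor_list_from_tensor(t, element_shape=t.shape[1:])
  return tensor_list_stack(l, element_dtype=t.dtype)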


def tensor_list_split(tensor, element_shape, lengths, name=None):
  return gen_list_ops.tensor_list_split(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      lengths=lengths,
      name=name)


def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in input list."""
  if resize_if_index_out_of_bounds:
    input_list_size = gen_list_ops.tensor_list_length(input_handle)
    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    input_handle = control_flow_ops.cond(
        index >= input_list_size,
        lambda: gen_list_ops.tensor_list_resize(  # pylint: disable=g-long-lambda
            input_handle, index + 1),
        lambda: input_handle)
  output_handle = gen_list_ops.tensor_list_set_item(
      input_handle=input_handle, index=index, item=item, name=name)
  handle_data_util.copy_handle_data(input_handle, output_handle)
  return output_handle


@ops.RegisterGradient("TensorListPushBack")
def _PushBackGrad(op, dresult):
  return gen_list_ops.tensor_list_pop_back(
      dresult,
      element_shape=array_ops.shape(op.inputs[1]),
      element_dtype=op.get_attr("element_dtype"))


@ops.RegisterGradient("TensorListPopBack")
def _PopBackGrad(op, dlist, delement):
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=delement.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  if delement is None:
    delement = array_ops.zeros_like(op.outputs[1])
  return gen_list_ops.tensor_list_push_back(dlist, delement), None


@ops.RegisterGradient("TensorListStack")
def _TensorListStackGrad(unused_op, dtensor):
  return tensor_list_from_tensor(dtensor, element_shape=dtensor.shape[1:]), None


@ops.RegisterGradient("TensorListConcat")
@ops.RegisterGradient("TensorListConcatV2")
def _TensorListConcatGrad(op, dtensor, unused_dlengths):
  """Gradient function for TensorListConcat."""
  dlist = tensor_list_split(
      dtensor,
      element_shape=gen_list_ops.tensor_list_element_shape(
          op.inputs[0], shape_type=dtypes.int32),
      lengths=op.outputs[1])
  if op.type == "TensorListConcatV2":
    return dlist, None, None
  else:
    return dlist


@ops.RegisterGradient("TensorListSplit")
def _TensorListSplitGrad(op, dlist):
  tensor, _, lengths = op.inputs
  element_shape = array_ops.slice(array_ops.shape(tensor), [1], [-1])
  element_shape = array_ops.concat([[-1], element_shape], axis=0)
  return gen_list_ops.tensor_list_concat_v2(
      dlist,
      element_shape=element_shape,
      leading_dims=lengths,
      element_dtype=op.inputs[0].dtype)[0], None, None


@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
  """Gradient for TensorListFromTensor."""
  t = op.inputs[0]
  if t.shape.dims and t.shape.dims[0].value is not None:
    num_elements = t.shape.dims[0].value
  else:
    num_elements = None
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=t.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  tensor_grad = gen_list_ops.tensor_list_stack(
      dlist,
      element_shape=array_ops.slice(array_ops.shape(t), [1], [-1]),
      element_dtype=t.dtype,
      num_elements=num_elements)
  shape_grad = None
  return tensor_grad, shape_grad
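

# The gradient registrations in this file pair each forward op with its
# adjoint: stack <-> from_tensor, concat <-> split, gather <-> scatter,
# get_item <-> set_item. A hedged end-to-end check of that wiring for the
# from_tensor/stack pair (hypothetical helper; assumes eager execution):
def _sketch_stack_grad():
  """Differentiates stack(from_tensor(t)) w.r.t. t; expects all-ones."""
  from tensorflow.python.eager import backprop  # Local import for the sketch.
  t = ops.convert_to_tensor([[1.0, 2.0], [3.0, 4.0]])
  with backprop.GradientTape() as tape:
    tape.watch(t)
    l = tensor_list_from_tensor(t, element_shape=[2])
    out = tensor_list_stack(l, element_dtype=dtypes.float32)
  # The round trip is the identity, so the gradient is ones_like(t), produced
  # by _TensorListStackGrad and _TensorListFromTensorGrad above.
  return tape.gradient(out, t)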


@ops.RegisterGradient("TensorListGetItem")
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem."""
  list_size = gen_list_ops.tensor_list_length(op.inputs[0])
  list_grad = gen_list_ops.tensor_list_set_item(
      gen_list_ops.tensor_list_reserve(
          gen_list_ops.tensor_list_element_shape(op.inputs[0],
                                                 shape_type=dtypes.int32),
          list_size, element_dtype=ditem.dtype),
      index=op.inputs[1],
      item=ditem)
  index_grad = None
  element_shape_grad = None
  return list_grad, index_grad, element_shape_grad


@ops.RegisterGradient("TensorListSetItem")
def _TensorListSetItemGrad(op, dlist):
  """Gradient function for TensorListSetItem."""
  _, index, item = op.inputs
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=array_ops.zeros_like(item))
  index_grad = None
  element_grad = tensor_list_get_item(
      dlist,
      index,
      element_shape=array_ops.shape(item),
      element_dtype=item.dtype)
  return list_grad, index_grad, element_grad


@ops.RegisterGradient("TensorListResize")
def _TensorListResizeGrad(op, dlist):
  input_list, _ = op.inputs
  input_list_size = gen_list_ops.tensor_list_length(input_list)
  return gen_list_ops.tensor_list_resize(dlist, input_list_size), None


@ops.RegisterGradient("TensorListGather")
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather."""
  input_list, indices, _ = op.inputs
  element_shape = gen_list_ops.tensor_list_element_shape(
      input_list, shape_type=dtypes.int32)
  num_elements = gen_list_ops.tensor_list_length(input_list)
  dlist = tensor_list_reserve(element_shape, num_elements, dtensor.dtype)
  dlist = tensor_list_scatter(
      tensor=dtensor, indices=indices, input_handle=dlist)
  return dlist, None, None


@ops.RegisterGradient("TensorListScatter")
@ops.RegisterGradient("TensorListScatterV2")
def _TensorListScatterGrad(op, dlist):
  """Gradient function for TensorListScatter."""
  tensor = op.inputs[0]
  indices = op.inputs[1]
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  if op.type == "TensorListScatterV2":
    return dtensor, None, None, None
  else:
    return dtensor, None, None


@ops.RegisterGradient("TensorListScatterIntoExistingList")
def _TensorListScatterIntoExistingListGrad(op, dlist):
  """Gradient function for TensorListScatterIntoExistingList."""
  _, tensor, indices = op.inputs
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  zeros = array_ops.zeros_like(tensor)
  # Zero out the gradient list entries that were overwritten by the scatter.
  # `element_shape` is unused on the scatter-into-existing-list path, so it is
  # left at its default here.
  dlist = tensor_list_scatter(zeros, indices, input_handle=dlist)
  return dlist, dtensor, None
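

# A hedged illustration of the normalization performed by
# _build_element_shape (defined below); a hypothetical helper mirroring the
# cases spelled out in its docstring.
def _sketch_element_shapes():
  """Shows how different element_shape inputs are normalized."""
  assert _build_element_shape(None) == -1  # Unknown shape.
  assert _build_element_shape(tensor_shape.TensorShape(None)) == -1
  assert _build_element_shape([None, 2]) == [-1, 2]  # Unknown dimension.
  # A scalar shape becomes an empty int32 tensor rather than a raw [],
  # because ops.convert_to_tensor([]) would default to float32.
  assert _build_element_shape([]).dtype == dtypes.int32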


def _build_element_shape(shape):
  """Converts `shape` into a format understood by list_ops for element_shape.

  If `shape` is already a `Tensor`, it is returned as-is. We do not perform a
  type check here.

  If `shape` is None or a `TensorShape` of unknown rank, -1 is returned.

  If `shape` is a scalar, an int32 tensor holding an empty list is returned.
  Note that we do not directly return an empty list, since
  `ops.convert_to_tensor` would convert it to a float32 tensor, which is not a
  valid type for element_shape.

  If `shape` is a sequence of dims, `None`s in the list are replaced with -1.
  We do not check the dtype of the other dims.

  Args:
    shape: Could be None, a Tensor, a TensorShape, or a list of dims (each dim
      could be None, a scalar, or a Tensor).

  Returns:
    A None-free shape that can be converted to a tensor.
  """
  if isinstance(shape, ops.Tensor):
    return shape
  if isinstance(shape, tensor_shape.TensorShape):
    # `TensorShape.as_list` requires rank to be known.
    shape = shape.as_list() if shape else None
  # Shape is unknown.
  if shape is None:
    return -1
  # Shape is a numpy array or a scalar.
  if isinstance(shape, (np.ndarray, np.generic)) or not shape:
    return ops.convert_to_tensor(shape, dtype=dtypes.int32)

  # Shape is a sequence of dimensions. Convert None dims to -1.
  def convert(val):
    if val is None:
      return -1
    if isinstance(val, ops.Tensor):
      return val
    if isinstance(val, tensor_shape.Dimension):
      return val.value if val.value is not None else -1
    return val

  return [convert(d) for d in shape]
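

# A hedged sketch of the scatter/gather pair (hypothetical helper): rows of a
# tensor are scattered into list slots by index, then read back with gather.
def _sketch_scatter_gather():
  """Scatters two rows into slots 2 and 0, then gathers slot 0."""
  t = ops.convert_to_tensor([[1.0, 2.0], [3.0, 4.0]])
  # With no input_handle, tensor_list_scatter builds a new list sized by the
  # largest index, so the result below has three slots (slot 1 unset).
  l = tensor_list_scatter(t, indices=[2, 0], element_shape=[2])
  return tensor_list_gather(
      l, indices=[0], element_dtype=dtypes.float32)  # [[3.0, 4.0]]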