# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to manipulate lists of tensors."""

# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_list_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_list_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util.lazy_loader import LazyLoader

# list_ops -> control_flow_ops -> tensor_array_ops -> list_ops
control_flow_ops = LazyLoader(
    "control_flow_ops", globals(),
    "tensorflow.python.ops.control_flow_ops")


ops.NotDifferentiable("TensorListConcatLists")
ops.NotDifferentiable("TensorListElementShape")
ops.NotDifferentiable("TensorListLength")
ops.NotDifferentiable("TensorListPushBackBatch")


def empty_tensor_list(element_shape,
                      element_dtype,
                      max_num_elements=None,
                      name=None):
  """Creates an empty `TensorList` with the given element shape and dtype."""
  if max_num_elements is None:
    max_num_elements = -1

  return gen_list_ops.empty_tensor_list(
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      max_num_elements=max_num_elements,
      name=name)


def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
  """Creates a `TensorList` with `num_elements` uninitialized elements."""
  return gen_list_ops.tensor_list_reserve(
      element_shape=_build_element_shape(element_shape),
      num_elements=num_elements,
      element_dtype=element_dtype,
      name=name)


def tensor_list_from_tensor(tensor, element_shape, name=None):
  """Builds a `TensorList` by unstacking the first dimension of `tensor`."""
  return gen_list_ops.tensor_list_from_tensor(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      name=name)


def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None,
                         name=None):
  """Returns the element at `index` in the input list."""
  return gen_list_ops.tensor_list_get_item(
      input_handle=input_handle,
      index=index,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      name=name)


def tensor_list_pop_back(input_handle, element_dtype, name=None):
  """Removes and returns the last element of the input list."""
  return gen_list_ops.tensor_list_pop_back(
      input_handle=input_handle,
      element_shape=-1,
      element_dtype=element_dtype,
      name=name)


def tensor_list_gather(input_handle,
                       indices,
                       element_dtype,
                       element_shape=None,
                       name=None):
  """Gathers the elements at `indices` into a stacked tensor."""
  return gen_list_ops.tensor_list_gather(
      input_handle=input_handle,
      indices=indices,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      name=name)


def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
  """Scatters slices of `tensor` at `indices` into a new or existing list."""
  if input_handle is not None:
    return gen_list_ops.tensor_list_scatter_into_existing_list(
        input_handle=input_handle, tensor=tensor, indices=indices, name=name)
  else:
    return gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=_build_element_shape(element_shape),
        num_elements=-1,
        name=name)
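

# A minimal usage sketch of the construction ops above (illustrative only; it
# assumes eager execution, so it is kept in a comment rather than executed at
# import time):
#
#   l = empty_tensor_list(element_shape=[2], element_dtype=dtypes.float32)
#   l = gen_list_ops.tensor_list_push_back(l, [1.0, 2.0])
#   l = gen_list_ops.tensor_list_push_back(l, [3.0, 4.0])
#   t = tensor_list_stack(l, element_dtype=dtypes.float32)  # shape [2, 2]
#   e = tensor_list_get_item(l, 0, element_dtype=dtypes.float32)  # [1.0, 2.0]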


def tensor_list_stack(input_handle,
                      element_dtype,
                      num_elements=-1,
                      element_shape=None,
                      name=None):
  """Stacks all elements of the input list into a single tensor."""
  return gen_list_ops.tensor_list_stack(
      input_handle=input_handle,
      element_shape=_build_element_shape(element_shape),
      element_dtype=element_dtype,
      num_elements=num_elements,
      name=name)


def tensor_list_concat(input_handle, element_dtype, element_shape=None,
                       name=None):
  """Concatenates the list elements along their first dimension."""
  # Ignore the lengths output of TensorListConcat. It is only used during
  # gradient computation.
  return gen_list_ops.tensor_list_concat_v2(
      input_handle=input_handle,
      element_dtype=element_dtype,
      element_shape=_build_element_shape(element_shape),
      leading_dims=ops.convert_to_tensor([], dtype=dtypes.int64),
      name=name)[0]


def tensor_list_split(tensor, element_shape, lengths, name=None):
  """Splits `tensor` into a list using row counts from `lengths`."""
  return gen_list_ops.tensor_list_split(
      tensor=tensor,
      element_shape=_build_element_shape(element_shape),
      lengths=lengths,
      name=name)


def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in the input list."""
  if resize_if_index_out_of_bounds:
    input_list_size = gen_list_ops.tensor_list_length(input_handle)
    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    input_handle = control_flow_ops.cond(
        index >= input_list_size,
        lambda: gen_list_ops.tensor_list_resize(  # pylint: disable=g-long-lambda
            input_handle, index + 1),
        lambda: input_handle)
  return gen_list_ops.tensor_list_set_item(
      input_handle=input_handle, index=index, item=item, name=name)


@ops.RegisterGradient("TensorListPushBack")
def _PushBackGrad(op, dresult):
  """Gradient function for TensorListPushBack."""
  return gen_list_ops.tensor_list_pop_back(
      dresult,
      element_shape=array_ops.shape(op.inputs[1]),
      element_dtype=op.get_attr("element_dtype"))


@ops.RegisterGradient("TensorListPopBack")
def _PopBackGrad(op, dlist, delement):
  """Gradient function for TensorListPopBack."""
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=delement.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  return gen_list_ops.tensor_list_push_back(dlist, delement), None


@ops.RegisterGradient("TensorListStack")
def _TensorListStackGrad(unused_op, dtensor):
  """Gradient function for TensorListStack."""
  return tensor_list_from_tensor(dtensor, element_shape=dtensor.shape[1:]), None


@ops.RegisterGradient("TensorListConcat")
@ops.RegisterGradient("TensorListConcatV2")
def _TensorListConcatGrad(op, dtensor, unused_dlengths):
  """Gradient function for TensorListConcat."""
  dlist = tensor_list_split(
      dtensor,
      element_shape=gen_list_ops.tensor_list_element_shape(
          op.inputs[0], shape_type=dtypes.int32),
      lengths=op.outputs[1])
  if op.type == "TensorListConcatV2":
    return dlist, None, None
  else:
    return dlist
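

# Concat and split are inverses of each other, which is what the gradient
# above relies on: the lengths output of TensorListConcat records the size of
# the leading dimension each element contributed, and tensor_list_split uses
# those lengths to route slices of the incoming gradient back to the
# corresponding list elements. Illustrative sketch (assumes eager execution;
# not executed at import time):
#
#   l = tensor_list_from_tensor([[1., 2.], [3., 4.]], element_shape=[2])
#   t = tensor_list_concat(l, element_dtype=dtypes.float32)  # [1., 2., 3., 4.]
#   # Splitting t with lengths [2, 2] recovers the original two elements,
#   # mirroring how _TensorListConcatGrad maps dtensor back onto a list.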


@ops.RegisterGradient("TensorListSplit")
def _TensorListSplitGrad(op, dlist):
  """Gradient function for TensorListSplit."""
  tensor, _, lengths = op.inputs
  element_shape = array_ops.slice(array_ops.shape(tensor), [1], [-1])
  element_shape = array_ops.concat([[-1], element_shape], axis=0)
  return gen_list_ops.tensor_list_concat_v2(
      dlist,
      element_shape=element_shape,
      leading_dims=lengths,
      element_dtype=op.inputs[0].dtype)[0], None, None


@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
  """Gradient for TensorListFromTensor."""
  t = op.inputs[0]
  if t.shape.dims and t.shape.dims[0].value is not None:
    num_elements = t.shape.dims[0].value
  else:
    num_elements = None
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=t.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  tensor_grad = gen_list_ops.tensor_list_stack(
      dlist,
      element_shape=array_ops.slice(array_ops.shape(t), [1], [-1]),
      element_dtype=t.dtype,
      num_elements=num_elements)
  shape_grad = None
  return tensor_grad, shape_grad


@ops.RegisterGradient("TensorListGetItem")
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem."""
  list_size = gen_list_ops.tensor_list_length(op.inputs[0])
  list_grad = gen_list_ops.tensor_list_set_item(
      gen_list_ops.tensor_list_reserve(
          gen_list_ops.tensor_list_element_shape(op.inputs[0],
                                                 shape_type=dtypes.int32),
          list_size, element_dtype=ditem.dtype),
      index=op.inputs[1],
      item=ditem)
  index_grad = None
  element_shape_grad = None
  return list_grad, index_grad, element_shape_grad


@ops.RegisterGradient("TensorListSetItem")
def _TensorListSetItemGrad(op, dlist):
  """Gradient function for TensorListSetItem."""
  _, index, item = op.inputs
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=index, item=array_ops.zeros_like(item))
  index_grad = None
  element_grad = tensor_list_get_item(
      dlist,
      index,
      element_shape=array_ops.shape(item),
      element_dtype=item.dtype)
  return list_grad, index_grad, element_grad


@ops.RegisterGradient("TensorListResize")
def _TensorListResizeGrad(op, dlist):
  """Gradient function for TensorListResize."""
  input_list, _ = op.inputs
  input_list_size = gen_list_ops.tensor_list_length(input_list)
  return gen_list_ops.tensor_list_resize(dlist, input_list_size), None


@ops.RegisterGradient("TensorListGather")
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather."""
  input_list, indices, _ = op.inputs
  element_shape = gen_list_ops.tensor_list_element_shape(
      input_list, shape_type=dtypes.int32)
  num_elements = gen_list_ops.tensor_list_length(input_list)
  dlist = tensor_list_reserve(element_shape, num_elements, dtensor.dtype)
  dlist = tensor_list_scatter(
      tensor=dtensor, indices=indices, input_handle=dlist)
  return dlist, None, None


@ops.RegisterGradient("TensorListScatter")
@ops.RegisterGradient("TensorListScatterV2")
def _TensorListScatterGrad(op, dlist):
  """Gradient function for TensorListScatter."""
  tensor = op.inputs[0]
  indices = op.inputs[1]
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  if op.type == "TensorListScatterV2":
    return dtensor, None, None, None
  else:
    return dtensor, None, None
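

# Gather and scatter are duals: the gather gradient above scatters the
# incoming per-element gradients back to their source indices, and the
# scatter gradients below gather from the incoming list gradient at the
# scattered indices. Illustrative sketch (assumes eager execution; not
# executed at import time):
#
#   l = tensor_list_reserve([2], num_elements=3, element_dtype=dtypes.float32)
#   l = tensor_list_scatter([[5., 6.]], indices=[1], input_handle=l)
#   t = tensor_list_gather(l, [1], element_dtype=dtypes.float32)  # [[5., 6.]]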


@ops.RegisterGradient("TensorListScatterIntoExistingList")
def _TensorListScatterIntoExistingListGrad(op, dlist):
  """Gradient function for TensorListScatterIntoExistingList."""
  _, tensor, indices = op.inputs
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  zeros = array_ops.zeros_like(tensor)
  dlist = tensor_list_scatter(zeros, indices, input_handle=dlist)
  return dlist, dtensor, None


def _build_element_shape(shape):
  """Converts shape to a format understood by list_ops for element_shape.

  If `shape` is already a `Tensor` it is returned as-is. We do not perform a
  type check here.

  If shape is None or a TensorShape with unknown rank, -1 is returned.

  If shape is a scalar, an int32 tensor with an empty list is returned. Note
  that we do not directly return the empty list, since ops.convert_to_tensor
  would convert it to a float32, which is not a valid type for element_shape.

  If shape is a sequence of dims, None's in the list are replaced with -1. We
  do not check the dtype of the other dims.

  Args:
    shape: Could be None, Tensor, TensorShape or a list of dims (each dim could
      be a None, scalar or Tensor).

  Returns:
    A None-free shape that can be converted to a tensor.
  """
  if isinstance(shape, ops.Tensor):
    return shape
  if isinstance(shape, tensor_shape.TensorShape):
    # `TensorShape.as_list` requires rank to be known.
    shape = shape.as_list() if shape else None
  # Shape is unknown.
  if shape is None:
    return -1
  # Shape is a scalar.
  if not shape:
    return ops.convert_to_tensor(shape, dtype=dtypes.int32)
  # Shape is a sequence of dimensions. Convert None dims to -1.
  def convert(val):
    if val is None:
      return -1
    if isinstance(val, ops.Tensor):
      return val
    if isinstance(val, tensor_shape.Dimension):
      return val.value if val.value is not None else -1
    return val

  return [convert(d) for d in shape]
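

# Illustrative examples of how _build_element_shape normalizes its input
# (shown as comments; the values are what the conversions above produce):
#
#   _build_element_shape(None)                            # -1 (unknown rank)
#   _build_element_shape(tensor_shape.TensorShape(None))  # -1 (unknown rank)
#   _build_element_shape([])                              # int32 tensor []
#   _build_element_shape([None, 2])                       # [-1, 2]
#   _build_element_shape(tensor_shape.TensorShape([3, None]))  # [3, -1]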