1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Seq2seq layer operations for use in neural networks.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import abc 22import six 23 24from tensorflow.python.eager import context 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import ops 28from tensorflow.python.framework import tensor_shape 29from tensorflow.python.framework import tensor_util 30from tensorflow.python.keras import layers 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import control_flow_ops 33from tensorflow.python.ops import control_flow_util 34from tensorflow.python.ops import math_ops 35from tensorflow.python.ops import rnn 36from tensorflow.python.ops import rnn_cell_impl 37from tensorflow.python.ops import tensor_array_ops 38from tensorflow.python.ops import variable_scope 39from tensorflow.python.util import nest 40 41 42__all__ = ["Decoder", "dynamic_decode"] 43 44 45_transpose_batch_time = rnn._transpose_batch_time # pylint: disable=protected-access 46_zero_state_tensors = rnn_cell_impl._zero_state_tensors # pylint: disable=protected-access 47 48 49@six.add_metaclass(abc.ABCMeta) 50class Decoder(object): 51 """An RNN Decoder abstract interface object. 52 53 Concepts used by this interface: 54 - `inputs`: (structure of) tensors and TensorArrays that is passed as input to 55 the RNNCell composing the decoder, at each time step. 56 - `state`: (structure of) tensors and TensorArrays that is passed to the 57 RNNCell instance as the state. 58 - `finished`: boolean tensor telling whether each sequence in the batch is 59 finished. 60 - `outputs`: Instance of BasicDecoderOutput. Result of the decoding, at each 61 time step. 62 """ 63 64 @property 65 def batch_size(self): 66 """The batch size of input values.""" 67 raise NotImplementedError 68 69 @property 70 def output_size(self): 71 """A (possibly nested tuple of...) integer[s] or `TensorShape` object[s].""" 72 raise NotImplementedError 73 74 @property 75 def output_dtype(self): 76 """A (possibly nested tuple of...) dtype[s].""" 77 raise NotImplementedError 78 79 @abc.abstractmethod 80 def initialize(self, name=None): 81 """Called before any decoding iterations. 82 83 This methods must compute initial input values and initial state. 84 85 Args: 86 name: Name scope for any created operations. 87 88 Returns: 89 `(finished, initial_inputs, initial_state)`: initial values of 90 'finished' flags, inputs and state. 91 """ 92 raise NotImplementedError 93 94 @abc.abstractmethod 95 def step(self, time, inputs, state, name=None): 96 """Called per step of decoding (but only once for dynamic decoding). 97 98 Args: 99 time: Scalar `int32` tensor. Current step number. 100 inputs: RNNCell input (possibly nested tuple of) tensor[s] for this time 101 step. 102 state: RNNCell state (possibly nested tuple of) tensor[s] from previous 103 time step. 104 name: Name scope for any created operations. 105 106 Returns: 107 `(outputs, next_state, next_inputs, finished)`: `outputs` is an object 108 containing the decoder output, `next_state` is a (structure of) state 109 tensors and TensorArrays, `next_inputs` is the tensor that should be used 110 as input for the next step, `finished` is a boolean tensor telling whether 111 the sequence is complete, for each sequence in the batch. 112 """ 113 raise NotImplementedError 114 115 def finalize(self, outputs, final_state, sequence_lengths): 116 raise NotImplementedError 117 118 @property 119 def tracks_own_finished(self): 120 """Describes whether the Decoder keeps track of finished states. 121 122 Most decoders will emit a true/false `finished` value independently 123 at each time step. In this case, the `dynamic_decode` function keeps track 124 of which batch entries are already finished, and performs a logical OR to 125 insert new batches to the finished set. 126 127 Some decoders, however, shuffle batches / beams between time steps and 128 `dynamic_decode` will mix up the finished state across these entries because 129 it does not track the reshuffle across time steps. In this case, it is 130 up to the decoder to declare that it will keep track of its own finished 131 state by setting this property to `True`. 132 133 Returns: 134 Python bool. 135 """ 136 return False 137 138 139class BaseDecoder(layers.Layer): 140 """An RNN Decoder that is based on a Keras layer. 141 142 Concepts used by this interface: 143 - `inputs`: (structure of) tensors and TensorArrays that is passed as input to 144 the RNNCell composing the decoder, at each time step. 145 - `state`: (structure of) tensors and TensorArrays that is passed to the 146 RNNCell instance as the state. 147 - `memory`: (sturecute of) tensors that is usually the full output of the 148 encoder, which will be used for the attention wrapper for the RNNCell. 149 - `finished`: boolean tensor telling whether each sequence in the batch is 150 finished. 151 - `outputs`: Instance of BasicDecoderOutput. Result of the decoding, at each 152 time step. 153 """ 154 155 def __init__(self, 156 output_time_major=False, 157 impute_finished=False, 158 maximum_iterations=None, 159 parallel_iterations=32, 160 swap_memory=False, 161 **kwargs): 162 self.output_time_major = output_time_major 163 self.impute_finished = impute_finished 164 self.maximum_iterations = maximum_iterations 165 self.parallel_iterations = parallel_iterations 166 self.swap_memory = swap_memory 167 super(BaseDecoder, self).__init__(**kwargs) 168 169 def call(self, inputs, initial_state=None, **kwargs): 170 init_kwargs = kwargs 171 init_kwargs["initial_state"] = initial_state 172 return dynamic_decode(self, 173 output_time_major=self.output_time_major, 174 impute_finished=self.impute_finished, 175 maximum_iterations=self.maximum_iterations, 176 parallel_iterations=self.parallel_iterations, 177 swap_memory=self.swap_memory, 178 decoder_init_input=inputs, 179 decoder_init_kwargs=init_kwargs) 180 181 @property 182 def batch_size(self): 183 """The batch size of input values.""" 184 raise NotImplementedError 185 186 @property 187 def output_size(self): 188 """A (possibly nested tuple of...) integer[s] or `TensorShape` object[s].""" 189 raise NotImplementedError 190 191 @property 192 def output_dtype(self): 193 """A (possibly nested tuple of...) dtype[s].""" 194 raise NotImplementedError 195 196 def initialize(self, inputs, initial_state=None, **kwargs): 197 """Called before any decoding iterations. 198 199 This methods must compute initial input values and initial state. 200 201 Args: 202 inputs: (structure of) tensors that contains the input for the decoder. In 203 the normal case, its a tensor with shape [batch, timestep, embedding]. 204 initial_state: (structure of) tensors that contains the initial state for 205 the RNNCell. 206 **kwargs: Other arguments that are passed in from layer.call() method. It 207 could contains item like input sequence_length, or masking for input. 208 209 Returns: 210 `(finished, initial_inputs, initial_state)`: initial values of 211 'finished' flags, inputs and state. 212 """ 213 raise NotImplementedError 214 215 def step(self, time, inputs, state): 216 """Called per step of decoding (but only once for dynamic decoding). 217 218 Args: 219 time: Scalar `int32` tensor. Current step number. 220 inputs: RNNCell input (possibly nested tuple of) tensor[s] for this time 221 step. 222 state: RNNCell state (possibly nested tuple of) tensor[s] from previous 223 time step. 224 225 Returns: 226 `(outputs, next_state, next_inputs, finished)`: `outputs` is an object 227 containing the decoder output, `next_state` is a (structure of) state 228 tensors and TensorArrays, `next_inputs` is the tensor that should be used 229 as input for the next step, `finished` is a boolean tensor telling whether 230 the sequence is complete, for each sequence in the batch. 231 """ 232 raise NotImplementedError 233 234 def finalize(self, outputs, final_state, sequence_lengths): 235 raise NotImplementedError 236 237 @property 238 def tracks_own_finished(self): 239 """Describes whether the Decoder keeps track of finished states. 240 241 Most decoders will emit a true/false `finished` value independently 242 at each time step. In this case, the `dynamic_decode` function keeps track 243 of which batch entries are already finished, and performs a logical OR to 244 insert new batches to the finished set. 245 246 Some decoders, however, shuffle batches / beams between time steps and 247 `dynamic_decode` will mix up the finished state across these entries because 248 it does not track the reshuffle across time steps. In this case, it is 249 up to the decoder to declare that it will keep track of its own finished 250 state by setting this property to `True`. 251 252 Returns: 253 Python bool. 254 """ 255 return False 256 257 # TODO(scottzhu): Add build/get_config/from_config and other layer methods. 258 259 260def _create_zero_outputs(size, dtype, batch_size): 261 """Create a zero outputs Tensor structure.""" 262 def _create(s, d): 263 return _zero_state_tensors(s, batch_size, d) 264 265 return nest.map_structure(_create, size, dtype) 266 267 268def dynamic_decode(decoder, 269 output_time_major=False, 270 impute_finished=False, 271 maximum_iterations=None, 272 parallel_iterations=32, 273 swap_memory=False, 274 scope=None, 275 **kwargs): 276 """Perform dynamic decoding with `decoder`. 277 278 Calls initialize() once and step() repeatedly on the Decoder object. 279 280 Args: 281 decoder: A `Decoder` instance. 282 output_time_major: Python boolean. Default: `False` (batch major). If 283 `True`, outputs are returned as time major tensors (this mode is faster). 284 Otherwise, outputs are returned as batch major tensors (this adds extra 285 time to the computation). 286 impute_finished: Python boolean. If `True`, then states for batch 287 entries which are marked as finished get copied through and the 288 corresponding outputs get zeroed out. This causes some slowdown at 289 each time step, but ensures that the final state and outputs have 290 the correct values and that backprop ignores time steps that were 291 marked as finished. 292 maximum_iterations: `int32` scalar, maximum allowed number of decoding 293 steps. Default is `None` (decode until the decoder is fully done). 294 parallel_iterations: Argument passed to `tf.while_loop`. 295 swap_memory: Argument passed to `tf.while_loop`. 296 scope: Optional variable scope to use. 297 **kwargs: dict, other keyword arguments for dynamic_decode. It might contain 298 arguments for `BaseDecoder` to initialize, which takes all tensor inputs 299 during call(). 300 301 Returns: 302 `(final_outputs, final_state, final_sequence_lengths)`. 303 304 Raises: 305 TypeError: if `decoder` is not an instance of `Decoder`. 306 ValueError: if `maximum_iterations` is provided but is not a scalar. 307 """ 308 if not isinstance(decoder, (Decoder, BaseDecoder)): 309 raise TypeError("Expected decoder to be type Decoder, but saw: %s" % 310 type(decoder)) 311 312 with variable_scope.variable_scope(scope, "decoder") as varscope: 313 # Determine context types. 314 ctxt = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access 315 is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None 316 in_while_loop = ( 317 control_flow_util.GetContainingWhileContext(ctxt) is not None) 318 # Properly cache variable values inside the while_loop. 319 # Don't set a caching device when running in a loop, since it is possible 320 # that train steps could be wrapped in a tf.while_loop. In that scenario 321 # caching prevents forward computations in loop iterations from re-reading 322 # the updated weights. 323 if not context.executing_eagerly() and not in_while_loop: 324 if varscope.caching_device is None: 325 varscope.set_caching_device(lambda op: op.device) 326 327 if maximum_iterations is not None: 328 maximum_iterations = ops.convert_to_tensor( 329 maximum_iterations, dtype=dtypes.int32, name="maximum_iterations") 330 if maximum_iterations.get_shape().ndims != 0: 331 raise ValueError("maximum_iterations must be a scalar") 332 333 if isinstance(decoder, Decoder): 334 initial_finished, initial_inputs, initial_state = decoder.initialize() 335 else: 336 # For BaseDecoder that takes tensor inputs during call. 337 decoder_init_input = kwargs.pop("decoder_init_input", None) 338 decoder_init_kwargs = kwargs.pop("decoder_init_kwargs", {}) 339 initial_finished, initial_inputs, initial_state = decoder.initialize( 340 decoder_init_input, **decoder_init_kwargs) 341 342 zero_outputs = _create_zero_outputs(decoder.output_size, 343 decoder.output_dtype, 344 decoder.batch_size) 345 346 if is_xla and maximum_iterations is None: 347 raise ValueError("maximum_iterations is required for XLA compilation.") 348 if maximum_iterations is not None: 349 initial_finished = math_ops.logical_or( 350 initial_finished, 0 >= maximum_iterations) 351 initial_sequence_lengths = array_ops.zeros_like( 352 initial_finished, dtype=dtypes.int32) 353 initial_time = constant_op.constant(0, dtype=dtypes.int32) 354 355 def _shape(batch_size, from_shape): 356 if (not isinstance(from_shape, tensor_shape.TensorShape) or 357 from_shape.ndims == 0): 358 return None 359 else: 360 batch_size = tensor_util.constant_value( 361 ops.convert_to_tensor( 362 batch_size, name="batch_size")) 363 return tensor_shape.TensorShape([batch_size]).concatenate(from_shape) 364 365 dynamic_size = maximum_iterations is None or not is_xla 366 367 def _create_ta(s, d): 368 return tensor_array_ops.TensorArray( 369 dtype=d, 370 size=0 if dynamic_size else maximum_iterations, 371 dynamic_size=dynamic_size, 372 element_shape=_shape(decoder.batch_size, s)) 373 374 initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size, 375 decoder.output_dtype) 376 377 def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs, 378 finished, unused_sequence_lengths): 379 return math_ops.logical_not(math_ops.reduce_all(finished)) 380 381 def body(time, outputs_ta, state, inputs, finished, sequence_lengths): 382 """Internal while_loop body. 383 384 Args: 385 time: scalar int32 tensor. 386 outputs_ta: structure of TensorArray. 387 state: (structure of) state tensors and TensorArrays. 388 inputs: (structure of) input tensors. 389 finished: bool tensor (keeping track of what's finished). 390 sequence_lengths: int32 tensor (keeping track of time of finish). 391 392 Returns: 393 `(time + 1, outputs_ta, next_state, next_inputs, next_finished, 394 next_sequence_lengths)`. 395 ``` 396 """ 397 (next_outputs, decoder_state, next_inputs, 398 decoder_finished) = decoder.step(time, inputs, state) 399 if decoder.tracks_own_finished: 400 next_finished = decoder_finished 401 else: 402 next_finished = math_ops.logical_or(decoder_finished, finished) 403 next_sequence_lengths = array_ops.where( 404 math_ops.logical_not(finished), 405 array_ops.fill(array_ops.shape(sequence_lengths), time + 1), 406 sequence_lengths) 407 408 nest.assert_same_structure(state, decoder_state) 409 nest.assert_same_structure(outputs_ta, next_outputs) 410 nest.assert_same_structure(inputs, next_inputs) 411 412 # Zero out output values past finish 413 if impute_finished: 414 emit = nest.map_structure( 415 lambda out, zero: array_ops.where(finished, zero, out), 416 next_outputs, 417 zero_outputs) 418 else: 419 emit = next_outputs 420 421 # Copy through states past finish 422 def _maybe_copy_state(new, cur): 423 # TensorArrays and scalar states get passed through. 424 if isinstance(cur, tensor_array_ops.TensorArray): 425 pass_through = True 426 else: 427 new.set_shape(cur.shape) 428 pass_through = (new.shape.ndims == 0) 429 return new if pass_through else array_ops.where(finished, cur, new) 430 431 if impute_finished: 432 next_state = nest.map_structure( 433 _maybe_copy_state, decoder_state, state) 434 else: 435 next_state = decoder_state 436 437 outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out), 438 outputs_ta, emit) 439 return (time + 1, outputs_ta, next_state, next_inputs, next_finished, 440 next_sequence_lengths) 441 442 res = control_flow_ops.while_loop( 443 condition, 444 body, 445 loop_vars=( 446 initial_time, 447 initial_outputs_ta, 448 initial_state, 449 initial_inputs, 450 initial_finished, 451 initial_sequence_lengths, 452 ), 453 parallel_iterations=parallel_iterations, 454 maximum_iterations=maximum_iterations, 455 swap_memory=swap_memory) 456 457 final_outputs_ta = res[1] 458 final_state = res[2] 459 final_sequence_lengths = res[5] 460 461 final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta) 462 463 try: 464 final_outputs, final_state = decoder.finalize( 465 final_outputs, final_state, final_sequence_lengths) 466 except NotImplementedError: 467 pass 468 469 if not output_time_major: 470 final_outputs = nest.map_structure(_transpose_batch_time, final_outputs) 471 472 return final_outputs, final_state, final_sequence_lengths 473