# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""TensorFlow Ops for Sequence to Sequence models (deprecated).

This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib import rnn
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.util.deprecation import deprecated


@deprecated(None, 'Please use tf.nn/tf.layers directly.')
def sequence_classifier(decoding, labels, sampling_decoding=None, name=None):
  """Returns predictions and loss for a sequence of predictions.

  Args:
    decoding: List of Tensors with predictions.
    labels: List of Tensors with labels.
    sampling_decoding: Optional list of Tensors with predictions to be used
      in sampling. E.g., they shouldn't have a dependency on outputs.
      If not provided, `decoding` is used.
    name: Operation name.

  Returns:
    Predictions and loss tensors.
  """
  with ops.name_scope(name, "sequence_classifier", [decoding, labels]):
    predictions, xent_list = [], []
    for i, pred in enumerate(decoding):
      xent_list.append(nn.softmax_cross_entropy_with_logits(
          labels=labels[i], logits=pred,
          name="sequence_loss/xent_raw{0}".format(i)))
      if sampling_decoding:
        predictions.append(nn.softmax(sampling_decoding[i]))
      else:
        predictions.append(nn.softmax(pred))
    xent = math_ops.add_n(xent_list, name="sequence_loss/xent")
    loss = math_ops.reduce_sum(xent, name="sequence_loss")
    return array_ops.stack(predictions, axis=1), loss
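

# The sketch below is illustrative only and not part of the original module:
# it shows one plausible way to wire per-step logits and one-hot labels into
# `sequence_classifier` in graph-mode TF 1.x. The helper name, placeholder
# shapes, and default sizes are hypothetical.
def _sequence_classifier_example(num_steps=3, num_classes=5):
  """Builds a toy graph that feeds per-step logits into `sequence_classifier`."""
  # One logits Tensor and one one-hot labels Tensor per decoding step.
  logits = [array_ops.placeholder(dtypes.float32, [None, num_classes],
                                  name="logits_%d" % i)
            for i in range(num_steps)]
  labels = [array_ops.placeholder(dtypes.float32, [None, num_classes],
                                  name="labels_%d" % i)
            for i in range(num_steps)]
  # `predictions` has shape [batch_size, num_steps, num_classes]; `loss` is a
  # scalar summed over steps and batch.
  predictions, loss = sequence_classifier(logits, labels)
  return predictions, loss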


@deprecated(None, 'Please use tf.nn/tf.layers directly.')
def seq2seq_inputs(x, y, input_length, output_length, sentinel=None, name=None):
  """Processes inputs for Sequence to Sequence models.

  Args:
    x: Input Tensor [batch_size, input_length, embed_dim].
    y: Output Tensor [batch_size, output_length, embed_dim].
    input_length: Length of input x.
    output_length: Length of output y.
    sentinel: Optional first input to the decoder and final expected output.
      If a sentinel is not provided, zeros are used. Because y is not
      available at sampling time, the dynamic batch dimension of the sentinel
      is inferred from x.
    name: Operation name.

  Returns:
    Encoder input from x, and decoder inputs and outputs from y.
  """
  with ops.name_scope(name, "seq2seq_inputs", [x, y]):
    in_x = array_ops.unstack(x, axis=1)
    y = array_ops.unstack(y, axis=1)
    if sentinel is None:
      # Set to zeros of the shape of y[0], using x for the batch size.
      sentinel_shape = array_ops.stack(
          [array_ops.shape(x)[0], y[0].get_shape()[1]])
      sentinel = array_ops.zeros(sentinel_shape)
      sentinel.set_shape(y[0].get_shape())
    in_y = [sentinel] + y
    out_y = y + [sentinel]
    return in_x, in_y, out_y


@deprecated(None, 'Please use tf.nn/tf.layers directly.')
def rnn_decoder(decoder_inputs, initial_state, cell, scope=None):
  """RNN Decoder that creates training and sampling sub-graphs.

  Args:
    decoder_inputs: Inputs for the decoder, a list of tensors. Used only in
      the training sub-graph.
    initial_state: Initial state for the decoder.
    cell: RNN cell to use for the decoder.
    scope: Scope to use; if None, a new one is created.

  Returns:
    Lists of output and state tensors for the training and sampling
    sub-graphs.
  """
  with vs.variable_scope(scope or "dnn_decoder"):
    states, sampling_states = [initial_state], [initial_state]
    outputs, sampling_outputs = [], []
    with ops.name_scope("training", values=[decoder_inputs, initial_state]):
      for i, inp in enumerate(decoder_inputs):
        if i > 0:
          vs.get_variable_scope().reuse_variables()
        output, new_state = cell(inp, states[-1])
        outputs.append(output)
        states.append(new_state)
    with ops.name_scope("sampling", values=[initial_state]):
      for i, _ in enumerate(decoder_inputs):
        if i == 0:
          sampling_outputs.append(outputs[i])
          sampling_states.append(states[i])
        else:
          # At sampling time, feed the previous sampled output back in as the
          # next input instead of the ground-truth decoder input.
          sampling_output, sampling_state = cell(sampling_outputs[-1],
                                                 sampling_states[-1])
          sampling_outputs.append(sampling_output)
          sampling_states.append(sampling_state)
  return outputs, states, sampling_outputs, sampling_states


@deprecated(None, 'Please use tf.nn/tf.layers directly.')
def rnn_seq2seq(encoder_inputs,
                decoder_inputs,
                encoder_cell,
                decoder_cell=None,
                dtype=dtypes.float32,
                scope=None):
  """RNN Sequence to Sequence model.

  Args:
    encoder_inputs: List of tensors, inputs for the encoder.
    decoder_inputs: List of tensors, inputs for the decoder.
    encoder_cell: RNN cell to use for the encoder.
    decoder_cell: RNN cell to use for the decoder; if None, encoder_cell is
      used.
    dtype: Type to initialize the encoder state with.
    scope: Scope to use; if None, a new one is created.

  Returns:
    Lists of output and state tensors for the training and sampling
    sub-graphs.
  """
  with vs.variable_scope(scope or "rnn_seq2seq"):
    _, last_enc_state = rnn.static_rnn(
        encoder_cell, encoder_inputs, dtype=dtype)
    return rnn_decoder(decoder_inputs, last_enc_state,
                       decoder_cell or encoder_cell)
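

# The end-to-end sketch below is illustrative only and not part of the
# original module: it shows one plausible way to combine `seq2seq_inputs`,
# `rnn_seq2seq`, and `sequence_classifier` in graph-mode TF 1.x. The helper
# name, the GRU cell choice, and all shapes/sizes are hypothetical.
def _rnn_seq2seq_example(input_length=4, output_length=3, num_classes=8):
  """Builds a toy encoder/decoder graph over one-hot, fixed-length sequences."""
  # Dense [batch, time, depth] placeholders; the batch size is left dynamic.
  x = array_ops.placeholder(dtypes.float32, [None, input_length, num_classes])
  y = array_ops.placeholder(dtypes.float32, [None, output_length, num_classes])
  # Split the dense inputs into per-step lists and add the sentinel step.
  in_x, in_y, out_y = seq2seq_inputs(x, y, input_length, output_length)
  # A single GRU cell sized to the number of classes so its outputs can be
  # used directly as logits (and fed back in during sampling).
  cell = rnn.GRUCell(num_classes)
  outputs, _, sampling_outputs, _ = rnn_seq2seq(in_x, in_y, cell)
  predictions, loss = sequence_classifier(outputs, out_y, sampling_outputs)
  return predictions, loss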