# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Functional RNN."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


from tensorflow.contrib.recurrent.python.ops import functional_rnn
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import rnn as rnn_lib
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test as test_lib
from tensorflow.python.platform import tf_logging as logging


def _CreateStackedLstmCell(*cell_sizes):
  subcells = [rnn_cell_impl.LSTMCell(cell_size) for cell_size in cell_sizes]
  return rnn_cell_impl.MultiRNNCell(subcells)


class FunctionalRnnTest(test_util.TensorFlowTestCase):

  _BATCH_SIZE = 3
  _TOTAL_TIME = 5
  _INPUT_SIZE = 11
  _NUM_UNITS = 7

  # Set this to some output if you want to use it.
  _LSTM_GRAPH_DEF_FILEPATH = None

  _CELLDEFS = {
      'gru': (rnn_cell_impl.GRUCell, [_NUM_UNITS]),
      'lstm': (rnn_cell_impl.LSTMCell, [_NUM_UNITS]),
      'stacked_lstm': (_CreateStackedLstmCell, [_NUM_UNITS] * 3)
  }

  def _CreateCell(self, celldef_name):
    func, args = self._CELLDEFS[celldef_name]
    return func(*args)

  def _CreateInputs(self, time_major=False):
    if time_major:
      inputs = np.random.random([
          FunctionalRnnTest._TOTAL_TIME, FunctionalRnnTest._BATCH_SIZE,
          FunctionalRnnTest._INPUT_SIZE
      ])
    else:
      inputs = np.random.random([
          FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._TOTAL_TIME,
          FunctionalRnnTest._INPUT_SIZE
      ])
    # Always leave one time slot empty, to check max_length behavior.
    sequence_length = np.random.randint(
        0, high=FunctionalRnnTest._TOTAL_TIME - 1,
        size=FunctionalRnnTest._BATCH_SIZE,
        dtype=np.int)
    return (inputs, sequence_length)

  def _CreateSymmetricInputs(self):
    # total time = batch size
    inputs = np.zeros(
        (FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._BATCH_SIZE,
         FunctionalRnnTest._INPUT_SIZE))
    for i in range(FunctionalRnnTest._BATCH_SIZE):
      for j in range(i, FunctionalRnnTest._BATCH_SIZE):
        inputs[i][j] = np.random.random([FunctionalRnnTest._INPUT_SIZE])
        inputs[j][i] = inputs[i][j]

    # Always leave one time slot empty, to check max_length behavior.
    sequence_length = np.random.randint(
        0,
        high=FunctionalRnnTest._BATCH_SIZE - 1,
        size=FunctionalRnnTest._BATCH_SIZE,
        dtype=np.int)
    return (inputs, sequence_length)

  def _CreateRnnGraph(self,
                      create_rnn_computation_func,
                      cell,
                      tf_inputs,
                      tf_sequence_length,
                      is_bidirectional,
                      initial_state=None,
                      time_major=None,
                      scope=None):
    if is_bidirectional:
      tf_result = create_rnn_computation_func(
          cell_fw=cell,
          cell_bw=cell,
          inputs=tf_inputs,
          sequence_length=tf_sequence_length,
          dtype=dtypes.float32,
          time_major=time_major,
          scope=scope)
    else:
      tf_result = create_rnn_computation_func(
          cell=cell,
          inputs=tf_inputs,
          sequence_length=tf_sequence_length,
          initial_state=initial_state,
          dtype=dtypes.float32,
          time_major=time_major,
          scope=scope)
    grad = gradients_impl.gradients(tf_result, variables.trainable_variables())
    return {'inference': tf_result, 'grad': grad}

  def _MaybeResetVariables(self, variable_cache, sess, var_list):
    """Possibly resets the variables to a previously seen value."""
    reset_ops = []
    fetches = []
    for var in var_list:
      if var.name in variable_cache:
        reset_ops += [var.assign(variable_cache[var.name])]
      else:
        fetches += [(var.name, var)]
    if reset_ops:
      sess.run(reset_ops)
    if fetches:
      val = sess.run(dict(fetches))
      for n, v in val.items():
        assert n not in variable_cache
        variable_cache[n] = v

  def _RunRnn(self, numpy_inputs, numpy_slen, cell_name, variable_cache,
              is_dynamic, time_major=None, is_bidirectional=False):
    with ops.Graph().as_default() as graph:
      tf_inputs = array_ops.placeholder(
          dtypes.float32, shape=numpy_inputs.shape)
      tf_slen = array_ops.placeholder(dtypes.int32)
      feeds = {tf_inputs: numpy_inputs, tf_slen: numpy_slen}
      cell = self._CreateCell(cell_name)
      if is_dynamic:
        if is_bidirectional:
          fn = rnn_lib.bidirectional_dynamic_rnn
        else:
          fn = rnn_lib.dynamic_rnn
      else:
        if is_bidirectional:
          fn = functional_rnn.bidirectional_functional_rnn
        else:
          fn = functional_rnn.functional_rnn

      fetches = self._CreateRnnGraph(
          fn, cell, tf_inputs, tf_slen, is_bidirectional, time_major=time_major)
      with self.session(graph=graph) as sess:
        sess.run(variables.global_variables_initializer())
        # Note that cell.trainable_variables is not always set.
        self._MaybeResetVariables(variable_cache, sess,
                                  variables.trainable_variables())
        val = sess.run(fetches, feed_dict=feeds)
      graph_def = graph.as_graph_def()
      return graph_def, val

  def testRunLstm(self):
    """Runs a simple LSTM. Does not check output."""
    np_inputs, np_slen = self._CreateInputs()
    var_cache = {}
    graphdef, _ = self._RunRnn(np_inputs, np_slen, 'lstm', var_cache, False)
    logging.info('graphdef: %s', graphdef)
    if self._LSTM_GRAPH_DEF_FILEPATH:
      with open(self._LSTM_GRAPH_DEF_FILEPATH, 'w') as f:
        f.write(str(graphdef))

  def testLstm(self):
    """Checks an LSTM against the reference implementation."""
    np_inputs, np_slen = self._CreateInputs()
    var_cache = {}
    _, func_rnn = self._RunRnn(np_inputs, np_slen, 'lstm', var_cache, False)
    _, dyn_rnn = self._RunRnn(np_inputs, np_slen, 'lstm', var_cache, True)
    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
    self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])

  def testGru(self):
    """Checks a GRU cell against the reference implementation."""
    np_inputs, np_slen = self._CreateInputs()
    var_cache = {}
    _, func_rnn = self._RunRnn(np_inputs, np_slen, 'gru', var_cache, False)
    _, dyn_rnn = self._RunRnn(np_inputs, np_slen, 'gru', var_cache, True)
    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
    self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])

  def testStackedLstm(self):
    """Checks a stacked LSTM cell against the reference implementation."""
    np_inputs, np_slen = self._CreateInputs()
    var_cache = {}
    args = [np_inputs, np_slen, 'stacked_lstm', var_cache]
    _, func_rnn = self._RunRnn(*(args + [False]))
    _, dyn_rnn = self._RunRnn(*(args + [True]))
    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
    self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])

  def testLstmWithTimeMajorInputs(self):
    """Checks an LSTM against the reference implementation, with time_major."""
    time_major = True
    np_inputs, np_slen = self._CreateInputs(time_major=True)
    var_cache = {}
    args = [np_inputs, np_slen, 'lstm', var_cache]
    _, func_rnn = self._RunRnn(*(args + [False]), time_major=time_major)
    _, dyn_rnn = self._RunRnn(*(args + [True]), time_major=time_major)
    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
    self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])

  def testBidirectionalLstmWithTimeMajorInputs(self):
    """Checks a bi-directional LSTM with time-major inputs."""
    time_major = True
    np_inputs, np_slen = self._CreateInputs(time_major)
    var_cache = {}
    args = [np_inputs, np_slen, 'lstm', var_cache]
    _, func_rnn = self._RunRnn(
        *(args + [False]), time_major=time_major, is_bidirectional=True)
    _, dyn_rnn = self._RunRnn(
        *(args + [True]), time_major=time_major, is_bidirectional=True)
    self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
    # TODO(b/112170761): uncomment this line after the bug is fixed.
    # self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])

  def testBidirectionalLstm(self):
    """Checks time-major and batch-major RNNs produce consistent results."""
    time_major_inputs, np_slen = self._CreateInputs(True)
    batch_major_inputs = np.transpose(time_major_inputs, [1, 0, 2])
    var_cache = {}
    args = [np_slen, 'lstm', var_cache, False]
    _, time_major_rnn = self._RunRnn(
        *([time_major_inputs] + args), time_major=True, is_bidirectional=True)
    _, batch_major_rnn = self._RunRnn(
        *([batch_major_inputs] + args), time_major=False, is_bidirectional=True)
    # Convert the batch-major outputs to be time-major before the comparison.
    outputs, state = batch_major_rnn['inference']
    outputs = [np.transpose(x, [1, 0, 2]) for x in outputs]
    batch_major_rnn['inference'] = [outputs, state]
    self.assertAllClose(time_major_rnn['inference'],
                        batch_major_rnn['inference'])
    self.assertAllClose(time_major_rnn['grad'], batch_major_rnn['grad'])

  def testBidirectionalLstmWithSymmetricInputs(self):
    """Checks a bi-directional LSTM with symmetric inputs.

    Time-major and batch-major runs should produce the same result with
    symmetric inputs.
    """
    np_inputs, np_slen = self._CreateSymmetricInputs()
    var_cache = {}
    args = [np_inputs, np_slen, 'lstm', var_cache]
    _, time_major_func_rnn = self._RunRnn(
        *(args + [False]), time_major=True, is_bidirectional=True)
    _, batch_major_func_rnn = self._RunRnn(
        *(args + [False]), time_major=False, is_bidirectional=True)
    _, time_major_dyn_rnn = self._RunRnn(
        *(args + [True]), time_major=True, is_bidirectional=True)
    _, batch_major_dyn_rnn = self._RunRnn(
        *(args + [True]), time_major=False, is_bidirectional=True)
    self.assertAllClose(time_major_func_rnn['inference'],
                        batch_major_func_rnn['inference'])
    self.assertAllClose(time_major_func_rnn['grad'],
                        batch_major_func_rnn['grad'])
    self.assertAllClose(time_major_dyn_rnn['inference'],
                        batch_major_dyn_rnn['inference'])
    self.assertAllClose(time_major_dyn_rnn['grad'], batch_major_dyn_rnn['grad'])
    self.assertAllClose(time_major_func_rnn['inference'],
                        batch_major_dyn_rnn['inference'])
    self.assertAllClose(time_major_func_rnn['grad'],
                        batch_major_dyn_rnn['grad'])


if __name__ == '__main__':
  test_lib.main()