#
# Copyright (C) 2017 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# LSTM Test: No Cifg, No Peephole, No Projection, and No Clipping.
#
# Declares a single LSTM operation with 23 operands (inputs, weights, biases,
# states and scalar parameters) and 4 outputs, then supplies the constant
# weight values for example 1.

model = Model()

n_batch = 1
n_input = 2
# n_cell and n_output have the same size when there is no projection.
n_cell = 4
n_output = 4

# NOTE(review): `input` shadows the Python builtin. The name is kept because the
# example section of this file keys `input0` by it.
input = Input("input", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_input))

# Input-to-gate weights, each shaped {n_cell, n_input}.
input_to_input_weights = Input("input_to_input_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))
input_to_forget_weights = Input("input_to_forget_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))
input_to_cell_weights = Input("input_to_cell_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))
input_to_output_weights = Input("input_to_output_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))

# Recurrent (previous-output-to-gate) weights, each shaped {n_cell, n_output}.
# The operand name was previously misspelled "recurrent_to_intput_weights".
recurrent_to_input_weights = Input("recurrent_to_input_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))
recurrent_to_forget_weights = Input("recurrent_to_forget_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))
recurrent_to_cell_weights = Input("recurrent_to_cell_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))
recurrent_to_output_weights = Input("recurrent_to_output_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))

# Peephole (cell-to-gate) weights: zero-sized placeholders, because this
# variant has no peephole connections (see the header comment).
cell_to_input_weights = Input("cell_to_input_weights", "TENSOR_FLOAT32", "{0}")
cell_to_forget_weights = Input("cell_to_forget_weights", "TENSOR_FLOAT32", "{0}")
cell_to_output_weights = Input("cell_to_output_weights", "TENSOR_FLOAT32", "{0}")

# Per-gate biases, each shaped {n_cell}.
input_gate_bias = Input("input_gate_bias", "TENSOR_FLOAT32", "{%d}" % (n_cell))
forget_gate_bias = Input("forget_gate_bias", "TENSOR_FLOAT32", "{%d}" % (n_cell))
cell_gate_bias = Input("cell_gate_bias", "TENSOR_FLOAT32", "{%d}" % (n_cell))
output_gate_bias = Input("output_gate_bias", "TENSOR_FLOAT32", "{%d}" % (n_cell))

# Projection layer: zero-sized placeholders, because this variant has no projection.
projection_weights = Input("projection_weights", "TENSOR_FLOAT32", "{0,0}")
projection_bias = Input("projection_bias", "TENSOR_FLOAT32", "{0}")

# Recurrent state carried into this step.
output_state_in = Input("output_state_in", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_output))
cell_state_in = Input("cell_state_in", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_cell))

activation_param = Int32Scalar("activation_param", 4)  # Tanh
# Clip thresholds of 0.0; the header comment states this test uses no clipping.
cell_clip_param = Float32Scalar("cell_clip_param", 0.)
proj_clip_param = Float32Scalar("proj_clip_param", 0.)

# Scratch buffer is declared with IgnoredOutput, so its contents are presumably
# not compared. Its width is n_cell * 4 (presumably one slice per gate).
scratch_buffer = IgnoredOutput("scratch_buffer", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, (n_cell * 4)))
output_state_out = Output("output_state_out", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_output))
cell_state_out = Output("cell_state_out", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_cell))
output = Output("output", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_output))

# Operands are passed positionally; keep this order in sync with the LSTM
# operation's input list.
model = model.Operation("LSTM",
                        input,

                        input_to_input_weights,
                        input_to_forget_weights,
                        input_to_cell_weights,
                        input_to_output_weights,

                        recurrent_to_input_weights,
                        recurrent_to_forget_weights,
                        recurrent_to_cell_weights,
                        recurrent_to_output_weights,

                        cell_to_input_weights,
                        cell_to_forget_weights,
                        cell_to_output_weights,

                        input_gate_bias,
                        forget_gate_bias,
                        cell_gate_bias,
                        output_gate_bias,

                        projection_weights,
                        projection_bias,

                        output_state_in,
                        cell_state_in,

                        activation_param,
                        cell_clip_param,
                        proj_clip_param
).To([scratch_buffer, output_state_out, cell_state_out, output])

# Example 1: constant operand values (weights, biases and empty placeholders).
# Weight matrices are listed as flat value lists (presumably row-major, matching
# the declared shapes — confirm against the generator). The per-step input and
# initial states are added to this dict below, before the Example is registered.
input0 = {input_to_input_weights: [-0.45018822, -0.02338299, -0.0870589, -0.34550029, 0.04266912, -0.15680569, -0.34856534, 0.43890524],
          input_to_forget_weights: [0.09701663, 0.20334584, -0.50592935, -0.31343272, -0.40032279, 0.44781327, 0.01387155, -0.35593212],
          input_to_cell_weights: [-0.50013041, 0.1370284, 0.11810488, 0.2013163, -0.20583314, 0.44344562, 0.22077113, -0.29909778],
          input_to_output_weights: [-0.25065863, -0.28290087, 0.04613829, 0.40525138, 0.44272184, 0.03897077, -0.1556896, 0.19487578],

          input_gate_bias: [0., 0., 0., 0.],
          forget_gate_bias: [1., 1., 1., 1.],
          cell_gate_bias: [0., 0., 0., 0.],
          output_gate_bias: [0., 0., 0., 0.],

          recurrent_to_input_weights: [
              -0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
              -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
              -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296],

          recurrent_to_cell_weights: [
              -0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
              -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
              -0.46367589, 0.26016325, -0.03894562, -0.16368064],

          recurrent_to_forget_weights: [
              -0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
              -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
              0.28053468, 0.01560611, -0.20127171, -0.01140004],

          recurrent_to_output_weights: [
              0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
              0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
              -0.51818722, -0.15390486, 0.0468148, 0.39922136],

          # No peephole or projection: these operands carry no data.
          cell_to_input_weights: [],
          cell_to_forget_weights: [],
          cell_to_output_weights: [],

          projection_weights: [],
          projection_bias: [],
}

test_input = [2., 3.]
# Example 1 starts from all-zero recurrent state.
output_state = [0, 0, 0, 0]
cell_state = [0, 0, 0, 0]
golden_output = [-0.02973187, 0.1229473, 0.20885126, -0.15358765]

# Expected outputs. output_state_out holds the same values as golden_output,
# written at lower precision.
output0 = {
    # Declared as IgnoredOutput above; filled with zeros of the declared size.
    scratch_buffer: [0] * (n_batch * n_cell * 4),
    cell_state_out: [-0.145439, 0.157475, 0.293663, -0.277353],
    output_state_out: [-0.0297319, 0.122947, 0.208851, -0.153588],
    output: golden_output,
}

# Complete the example-1 inputs with the per-step input and the initial states.
input0[input] = test_input
input0[output_state_in] = output_state
input0[cell_state_in] = cell_state
Example((input0, output0))