#
# Copyright (C) 2018 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# LSTM Test, With Cifg, With Peephole, No Projection, No Clipping
# (relaxed float32 execution).
#
# With CIFG (coupled input-forget gate) the input gate is not computed, so
# every input-gate operand (input_to_input_weights, recurrent_to_input_weights,
# cell_to_input_weights, input_gate_bias) is omitted: each is declared with a
# zero-sized shape and fed empty data in input0 below. The projection operands
# are omitted the same way.

model = Model()

n_batch = 1
n_input = 2
# n_cell and n_output have the same size when there is no projection.
n_cell = 4
n_output = 4

input = Input("input", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_input))

# Omitted under CIFG: zero-sized, consistent with the empty data in input0.
input_to_input_weights = Input("input_to_input_weights", "TENSOR_FLOAT32", "{0, 0}")
input_to_forget_weights = Input("input_to_forget_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))
input_to_cell_weights = Input("input_to_cell_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))
input_to_output_weights = Input("input_to_output_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))

# Omitted under CIFG.
# NOTE(review): the operand name string misspells "input" ("intput"); it is
# left unchanged so generated artifacts keyed on it do not churn.
recurrent_to_input_weights = Input("recurrent_to_intput_weights", "TENSOR_FLOAT32", "{0, 0}")
recurrent_to_forget_weights = Input("recurrent_to_forget_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))
recurrent_to_cell_weights = Input("recurrent_to_cell_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))
recurrent_to_output_weights = Input("recurrent_to_output_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))

# Peephole weights; the input-gate one is omitted under CIFG.
cell_to_input_weights = Input("cell_to_input_weights", "TENSOR_FLOAT32", "{0}")
cell_to_forget_weights = Input("cell_to_forget_weights", "TENSOR_FLOAT32", "{%d}" % (n_cell))
cell_to_output_weights = Input("cell_to_output_weights", "TENSOR_FLOAT32", "{%d}" % (n_cell))

# Omitted under CIFG.
input_gate_bias = Input("input_gate_bias", "TENSOR_FLOAT32", "{0}")
forget_gate_bias = Input("forget_gate_bias", "TENSOR_FLOAT32", "{%d}" % (n_cell))
cell_gate_bias = Input("cell_gate_bias", "TENSOR_FLOAT32", "{%d}" % (n_cell))
output_gate_bias = Input("output_gate_bias", "TENSOR_FLOAT32", "{%d}" % (n_cell))

# No projection: both projection operands are omitted.
projection_weights = Input("projection_weights", "TENSOR_FLOAT32", "{0,0}")
projection_bias = Input("projection_bias", "TENSOR_FLOAT32", "{0}")

output_state_in = Input("output_state_in", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_output))
cell_state_in = Input("cell_state_in", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_cell))

activation_param = Int32Scalar("activation_param", 4)  # Tanh
# No clipping: both clip thresholds are 0.
cell_clip_param = Float32Scalar("cell_clip_param", 0.)
proj_clip_param = Float32Scalar("proj_clip_param", 0.)

# Scratch buffer is n_cell * 3 wide per batch with CIFG (n_cell * 4 without);
# its contents are not checked by the test.
scratch_buffer = IgnoredOutput("scratch_buffer", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_cell * 3))
output_state_out = Output("output_state_out", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_output))
cell_state_out = Output("cell_state_out", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_cell))
output = Output("output", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_output))

model = model.Operation("LSTM",
    input,                          # 0

    input_to_input_weights,         # 1-4: input-to-gate weights
    input_to_forget_weights,
    input_to_cell_weights,
    input_to_output_weights,

    recurrent_to_input_weights,     # 5-8: recurrent-to-gate weights
    recurrent_to_forget_weights,
    recurrent_to_cell_weights,
    recurrent_to_output_weights,

    cell_to_input_weights,          # 9-11: peephole weights
    cell_to_forget_weights,
    cell_to_output_weights,

    input_gate_bias,                # 12-15: gate biases
    forget_gate_bias,
    cell_gate_bias,
    output_gate_bias,

    projection_weights,             # 16-17: projection
    projection_bias,

    output_state_in,                # 18-19: recurrent state
    cell_state_in,

    activation_param,               # 20-22: scalar parameters
    cell_clip_param,
    proj_clip_param
).To([scratch_buffer, output_state_out, cell_state_out, output])
# Relaxed float32 execution; presumably maps to
# ANeuralNetworksModel_relaxComputationFloat32toFloat16 — confirm in the
# test generator.
model = model.RelaxedExecution(True)

input0 = {
    input: [2., 3.],

    # Omitted (CIFG) operands get empty data.
    input_to_input_weights: [],
    input_to_forget_weights: [-0.55291498, -0.42866567, 0.13056988, -0.3633365, -0.22755712, 0.28253698, 0.24407166, 0.33826375],
    input_to_cell_weights: [-0.49770179, -0.27711356, -0.09624726, 0.05100781, 0.04717243, 0.48944736, -0.38535351, -0.17212132],
    input_to_output_weights: [0.10725588, -0.02335852, -0.55932593, -0.09426838, -0.44257352, 0.54939759, 0.01533556, 0.42751634],

    recurrent_to_input_weights: [],
    recurrent_to_forget_weights: [
        -0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
        0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
        -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349],
    recurrent_to_cell_weights: [
        0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
        0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
        0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
        0.21193194],
    recurrent_to_output_weights: [
        0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
        -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
        0.50248802, 0.26114327, -0.43736315, 0.33149987],

    cell_to_input_weights: [],
    cell_to_forget_weights: [0.47485286, -0.51955009, -0.24458408, 0.31544167],
    cell_to_output_weights: [-0.17135078, 0.82760304, 0.85573703, -0.77109635],

    input_gate_bias: [],
    forget_gate_bias: [1., 1., 1., 1.],
    cell_gate_bias: [0., 0., 0., 0.],
    output_gate_bias: [0., 0., 0., 0.],

    projection_weights: [],
    projection_bias: [],

    # Start from zero recurrent state.
    output_state_in: [0] * (n_batch * n_output),
    cell_state_in: [0] * (n_batch * n_cell),
}

output0 = {
    scratch_buffer: [0] * (n_batch * n_cell * 3),
    cell_state_out: [-0.760444, -0.0180416, 0.182264, -0.0649371],
    # With no projection, output_state_out carries the same values as
    # `output` (listed here with fewer significant digits).
    output_state_out: [-0.364445, -0.00352185, 0.128866, -0.0516365],
    output: [-0.36444446, -0.00352185, 0.12886585, -0.05163646],
}

Example((input0, output0))