• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#
# Copyright (C) 2017 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
16
# LSTM Test: No Cifg, No Peephole, No Projection, and No Clipping.
#
# This section declares the model and every operand of the LSTM operation:
# the data input, the four input-to-gate and four recurrent-to-gate weight
# matrices, the cell-to-gate (peephole) weights, the gate biases, the
# projection operands, the incoming states, the scalar parameters and the
# outputs.

model = Model()

n_batch = 1
n_input = 2
# n_cell and n_output have the same size when there is no projection.
n_cell = 4
n_output = 4

# Dimension strings shared by several operands, formatted once so that every
# operand of the same family is guaranteed to use the same shape.
input_weights_dims = "{%d, %d}" % (n_cell, n_input)
recurrent_weights_dims = "{%d, %d}" % (n_cell, n_output)
gate_bias_dims = "{%d}" % (n_cell)
output_state_dims = "{%d, %d}" % (n_batch, n_output)
cell_state_dims = "{%d, %d}" % (n_batch, n_cell)

# NOTE: `input` shadows the Python builtin of the same name; the name is kept
# because the rest of the script refers to this operand as `input`.
input = Input("input", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_input))

input_to_input_weights = Input("input_to_input_weights", "TENSOR_FLOAT32", input_weights_dims)
input_to_forget_weights = Input("input_to_forget_weights", "TENSOR_FLOAT32", input_weights_dims)
input_to_cell_weights = Input("input_to_cell_weights", "TENSOR_FLOAT32", input_weights_dims)
input_to_output_weights = Input("input_to_output_weights", "TENSOR_FLOAT32", input_weights_dims)

# The label of the first operand used to read "recurrent_to_intput_weights";
# it now matches the variable name and the naming of its siblings.
recurrent_to_input_weights = Input("recurrent_to_input_weights", "TENSOR_FLOAT32", recurrent_weights_dims)
recurrent_to_forget_weights = Input("recurrent_to_forget_weights", "TENSOR_FLOAT32", recurrent_weights_dims)
recurrent_to_cell_weights = Input("recurrent_to_cell_weights", "TENSOR_FLOAT32", recurrent_weights_dims)
recurrent_to_output_weights = Input("recurrent_to_output_weights", "TENSOR_FLOAT32", recurrent_weights_dims)

# Declared with a single zero-sized dimension.
# NOTE(review): presumably this is how the spec marks the peephole weights as
# omitted for a "No Peephole" test - confirm against the test generator.
cell_to_input_weights = Input("cell_to_input_weights", "TENSOR_FLOAT32", "{0}")
cell_to_forget_weights = Input("cell_to_forget_weights", "TENSOR_FLOAT32", "{0}")
cell_to_output_weights = Input("cell_to_output_weights", "TENSOR_FLOAT32", "{0}")

input_gate_bias = Input("input_gate_bias", "TENSOR_FLOAT32", gate_bias_dims)
forget_gate_bias = Input("forget_gate_bias", "TENSOR_FLOAT32", gate_bias_dims)
cell_gate_bias = Input("cell_gate_bias", "TENSOR_FLOAT32", gate_bias_dims)
output_gate_bias = Input("output_gate_bias", "TENSOR_FLOAT32", gate_bias_dims)

# Zero-sized as well.
# NOTE(review): presumably omitted because this test uses no projection - confirm.
projection_weights = Input("projection_weights", "TENSOR_FLOAT32", "{0,0}")
projection_bias = Input("projection_bias", "TENSOR_FLOAT32", "{0}")

output_state_in = Input("output_state_in", "TENSOR_FLOAT32", output_state_dims)
cell_state_in = Input("cell_state_in", "TENSOR_FLOAT32", cell_state_dims)

activation_param = Int32Scalar("activation_param", 4)  # Tanh
# Both clip parameters are 0.
# NOTE(review): presumably 0 disables clipping ("No Clipping") - confirm.
cell_clip_param = Float32Scalar("cell_clip_param", 0.)
proj_clip_param = Float32Scalar("proj_clip_param", 0.)

# The scratch buffer is declared as an ignored output holding four values per
# cell for each batch entry.
scratch_buffer = IgnoredOutput("scratch_buffer", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, (n_cell * 4)))
output_state_out = Output("output_state_out", "TENSOR_FLOAT32", output_state_dims)
cell_state_out = Output("cell_state_out", "TENSOR_FLOAT32", cell_state_dims)
output = Output("output", "TENSOR_FLOAT32", output_state_dims)
62
# Attach the LSTM operation to the model.
#
# The operands are handed to Operation() positionally, so the order of the
# entries below must be kept exactly as written.
lstm_inputs = [
    input,
    # Input-to-gate weights.
    input_to_input_weights,
    input_to_forget_weights,
    input_to_cell_weights,
    input_to_output_weights,
    # Recurrent-to-gate weights.
    recurrent_to_input_weights,
    recurrent_to_forget_weights,
    recurrent_to_cell_weights,
    recurrent_to_output_weights,
    # Cell-to-gate weights.
    cell_to_input_weights,
    cell_to_forget_weights,
    cell_to_output_weights,
    # Gate biases.
    input_gate_bias,
    forget_gate_bias,
    cell_gate_bias,
    output_gate_bias,
    # Projection operands.
    projection_weights,
    projection_bias,
    # Incoming state.
    output_state_in,
    cell_state_in,
    # Scalar parameters.
    activation_param,
    cell_clip_param,
    proj_clip_param,
]
lstm_outputs = [scratch_buffer, output_state_out, cell_state_out, output]

model = model.Operation("LSTM", *lstm_inputs).To(lstm_outputs)
95
# Example 1: one set of operand values together with the expected outputs.

# Values fed to `input`, `output_state_in` and `cell_state_in` respectively.
test_input = [3.0, 4.0]
output_state = [-0.0297319, 0.122947, 0.208851, -0.153588]
cell_state = [-0.145439, 0.157475, 0.293663, -0.277353]
# Expected contents of the `output` operand.
golden_output = [-0.03716109, 0.12507336, 0.41193449, -0.20860538]

# Operand -> value map for the inputs. The flat value lists are laid out in
# groups that mirror the dimensions the operands were declared with
# ({n_cell, n_input} and {n_cell, n_output}); the values and their order are
# unchanged.
input0 = {
    input_to_input_weights: [
        -0.45018822, -0.02338299,
        -0.0870589, -0.34550029,
        0.04266912, -0.15680569,
        -0.34856534, 0.43890524,
    ],
    input_to_forget_weights: [
        0.09701663, 0.20334584,
        -0.50592935, -0.31343272,
        -0.40032279, 0.44781327,
        0.01387155, -0.35593212,
    ],
    input_to_cell_weights: [
        -0.50013041, 0.1370284,
        0.11810488, 0.2013163,
        -0.20583314, 0.44344562,
        0.22077113, -0.29909778,
    ],
    input_to_output_weights: [
        -0.25065863, -0.28290087,
        0.04613829, 0.40525138,
        0.44272184, 0.03897077,
        -0.1556896, 0.19487578,
    ],

    input_gate_bias: [0.0, 0.0, 0.0, 0.0],
    forget_gate_bias: [1.0, 1.0, 1.0, 1.0],
    cell_gate_bias: [0.0, 0.0, 0.0, 0.0],
    output_gate_bias: [0.0, 0.0, 0.0, 0.0],

    recurrent_to_input_weights: [
        -0.0063535, -0.2042388, 0.31454784, -0.35746509,
        0.28902304, 0.08183324, -0.16555229, 0.02286911,
        -0.13566875, 0.03034258, 0.48091322, -0.12528998,
        0.24077177, -0.51332325, -0.33502164, 0.10629296,
    ],
    recurrent_to_cell_weights: [
        -0.3407414, 0.24443203, -0.2078532, 0.26320225,
        0.05695659, -0.00123841, -0.4744786, -0.35869038,
        -0.06418842, -0.13502428, -0.501764, 0.22830659,
        -0.46367589, 0.26016325, -0.03894562, -0.16368064,
    ],
    recurrent_to_forget_weights: [
        -0.48684245, -0.06655136, 0.42224967, 0.2112639,
        0.27654213, 0.20864892, -0.07646349, 0.45877004,
        0.00141793, -0.14609534, 0.36447752, 0.09196436,
        0.28053468, 0.01560611, -0.20127171, -0.01140004,
    ],
    recurrent_to_output_weights: [
        0.43385774, -0.17194885, 0.2718237, 0.09215671,
        0.24107647, -0.39835793, 0.18212086, 0.01301402,
        0.48572797, -0.50656658, 0.20047462, -0.20607421,
        -0.51818722, -0.15390486, 0.0468148, 0.39922136,
    ],

    # The cell-to-gate and projection operands receive no values.
    cell_to_input_weights: [],
    cell_to_forget_weights: [],
    cell_to_output_weights: [],

    projection_weights: [],
    projection_bias: [],

    # Data input and incoming state, listed last so the map is filled in the
    # same order as before.
    input: test_input,
    output_state_in: output_state,
    cell_state_in: cell_state,
}

# Operand -> expected value map for the outputs. The scratch buffer entry is
# filled with zeros, one per element of its declared shape.
output0 = {
    scratch_buffer: [0] * (n_batch * n_cell * 4),
    cell_state_out: [-0.287121, 0.148115, 0.556837, -0.388276],
    output_state_out: [-0.0371611, 0.125073, 0.411934, -0.208605],
    output: golden_output,
}

Example((input0, output0))
149