• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#
# Copyright (C) 2017 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# LSTM Test, With Cifg, With Peephole, No Projection, No Clipping.
#
# With CIFG there is no separate input gate, so every input-gate operand
# (input_to_input_weights, recurrent_to_input_weights, cell_to_input_weights,
# input_gate_bias) is declared zero-sized and is given an empty value in
# input0. The projection operands are zero-sized because there is no projection.
#
# NOTE(review): Model, Input, Output, IgnoredOutput, Int32Scalar and
# Float32Scalar are not defined in this file; presumably they are provided by
# the NNAPI test generator that executes this spec -- confirm.

model = Model()

n_batch = 1
n_input = 2
# n_cell and n_output have the same size when there is no projection.
n_cell = 4
n_output = 4

input = Input("input", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_input))

# Input weights, shape {n_cell, n_input}; the input-gate one is omitted (CIFG).
input_to_input_weights = Input("input_to_input_weights", "TENSOR_FLOAT32", "{0,0}")
input_to_forget_weights = Input("input_to_forget_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))
input_to_cell_weights = Input("input_to_cell_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))
input_to_output_weights = Input("input_to_output_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_input))

# Recurrent weights, shape {n_cell, n_output}; the input-gate one is omitted (CIFG).
recurrent_to_input_weights = Input("recurrent_to_input_weights", "TENSOR_FLOAT32", "{0,0}")
recurrent_to_forget_weights = Input("recurrent_to_forget_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))
recurrent_to_cell_weights = Input("recurrent_to_cell_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))
recurrent_to_output_weights = Input("recurrent_to_output_weights", "TENSOR_FLOAT32", "{%d, %d}" % (n_cell, n_output))

# Peephole (cell_to_*) weights, shape {n_cell}; the input-gate one is omitted (CIFG).
cell_to_input_weights = Input("cell_to_input_weights", "TENSOR_FLOAT32", "{0}")
cell_to_forget_weights = Input("cell_to_forget_weights", "TENSOR_FLOAT32", "{%d}" % n_cell)
cell_to_output_weights = Input("cell_to_output_weights", "TENSOR_FLOAT32", "{%d}" % n_cell)

# Gate biases, shape {n_cell}; the input-gate bias is omitted (CIFG).
input_gate_bias = Input("input_gate_bias", "TENSOR_FLOAT32", "{0}")
forget_gate_bias = Input("forget_gate_bias", "TENSOR_FLOAT32", "{%d}" % n_cell)
cell_gate_bias = Input("cell_gate_bias", "TENSOR_FLOAT32", "{%d}" % n_cell)
output_gate_bias = Input("output_gate_bias", "TENSOR_FLOAT32", "{%d}" % n_cell)

# No projection: both projection operands are zero-sized.
projection_weights = Input("projection_weights", "TENSOR_FLOAT32", "{0,0}")
projection_bias = Input("projection_bias", "TENSOR_FLOAT32", "{0}")

# Incoming recurrent state.
output_state_in = Input("output_state_in", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_output))
cell_state_in = Input("cell_state_in", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_cell))

activation_param = Int32Scalar("activation_param", 4)  # Tanh
# Clip thresholds of 0 -- this test is described as "No Clipping".
cell_clip_param = Float32Scalar("cell_clip_param", 0.)
proj_clip_param = Float32Scalar("proj_clip_param", 0.)

# NOTE(review): the scratch buffer is n_cell * 3 wide, presumably one slot per
# remaining gate because CIFG drops the input gate -- confirm against the LSTM
# operation spec. It is declared as IgnoredOutput.
scratch_buffer = IgnoredOutput("scratch_buffer", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_cell * 3))
output_state_out = Output("output_state_out", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_output))
cell_state_out = Output("cell_state_out", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_cell))
output = Output("output", "TENSOR_FLOAT32", "{%d, %d}" % (n_batch, n_output))

# Build the single LSTM operation. The operands are passed positionally, so
# their order matters. NOTE(review): presumably this order must match the NNAPI
# LSTM operation's input list -- confirm before reordering anything.
model = model.Operation("LSTM",
                        input,

                        # Input weights.
                        input_to_input_weights,
                        input_to_forget_weights,
                        input_to_cell_weights,
                        input_to_output_weights,

                        # Recurrent weights.
                        recurrent_to_input_weights,
                        recurrent_to_forget_weights,
                        recurrent_to_cell_weights,
                        recurrent_to_output_weights,

                        # Peephole weights.
                        cell_to_input_weights,
                        cell_to_forget_weights,
                        cell_to_output_weights,

                        # Gate biases.
                        input_gate_bias,
                        forget_gate_bias,
                        cell_gate_bias,
                        output_gate_bias,

                        # Projection (unused in this test).
                        projection_weights,
                        projection_bias,

                        # Incoming recurrent state.
                        output_state_in,
                        cell_state_in,

                        # Scalar parameters.
                        activation_param,
                        cell_clip_param,
                        proj_clip_param
).To([scratch_buffer, output_state_out, cell_state_out, output])

# Input values. The empty lists feed the operands that this configuration
# omits: the input gate (CIFG) and the projection.
input0 = {input_to_input_weights: [],
          input_to_cell_weights: [-0.49770179, -0.27711356, -0.09624726, 0.05100781,
                                  0.04717243, 0.48944736, -0.38535351, -0.17212132],
          input_to_forget_weights: [-0.55291498, -0.42866567, 0.13056988, -0.3633365,
                                    -0.22755712, 0.28253698, 0.24407166, 0.33826375],
          input_to_output_weights: [0.10725588, -0.02335852, -0.55932593, -0.09426838,
                                    -0.44257352, 0.54939759, 0.01533556, 0.42751634],

          input_gate_bias: [],
          forget_gate_bias: [1., 1., 1., 1.],
          cell_gate_bias: [0., 0., 0., 0.],
          output_gate_bias: [0., 0., 0., 0.],

          recurrent_to_input_weights: [],
          recurrent_to_cell_weights: [
              0.54066205, -0.32668582, -0.43562764, -0.56094903,
              0.42957711, 0.01841056, -0.32764608, -0.33027974,
              -0.10826075, 0.20675004, 0.19069612, -0.03026325,
              -0.54532051, 0.33003211, 0.44901288, 0.21193194],

          recurrent_to_forget_weights: [
              -0.13832897, -0.0515101, -0.2359007, -0.16661474,
              -0.14340827, 0.36986142, 0.23414481, 0.55899,
              0.10798943, -0.41174671, 0.17751795, -0.34484994,
              -0.35874045, -0.11352962, 0.27268326, 0.54058349],

          recurrent_to_output_weights: [
              0.41613156, 0.42610586, -0.16495961, -0.5663873,
              0.30579174, -0.05115908, -0.33941799, 0.23364776,
              0.11178309, 0.09481031, -0.26424935, 0.46261835,
              0.50248802, 0.26114327, -0.43736315, 0.33149987],

          cell_to_input_weights: [],
          cell_to_forget_weights: [0.47485286, -0.51955009, -0.24458408, 0.31544167],
          cell_to_output_weights: [-0.17135078, 0.82760304, 0.85573703, -0.77109635],

          projection_weights: [],
          projection_bias: [],
}

# Expected outputs. The scratch buffer is expected to be all zeros.
output0 = {
    scratch_buffer: [0] * (n_batch * n_cell * 3),
    cell_state_out: [-0.760444, -0.0180416, 0.182264, -0.0649371],
    output_state_out: [-0.364445, -0.00352185, 0.128866, -0.0516365],
}

input0[input] = [2., 3.]
# Both recurrent states start at zero.
input0[output_state_in] = [0] * (n_batch * n_output)
input0[cell_state_in] = [0] * (n_batch * n_cell)
# Same values as output_state_out (at higher precision): there is no projection.
output0[output] = [-0.36444446, -0.00352185, 0.12886585, -0.05163646]

Example((input0, output0))