• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op.h"
17 #include "tensorflow/core/framework/shape_inference.h"
18 
19 namespace tensorflow {
20 
21 using shape_inference::DimensionHandle;
22 using shape_inference::InferenceContext;
23 using shape_inference::ShapeHandle;
24 
25 REGISTER_OP("GRUBlockCell")
26     .Attr("T: {float}")
27     .Input("x: T")
28     .Input("h_prev: T")
29     .Input("w_ru: T")
30     .Input("w_c: T")
31     .Input("b_ru: T")
32     .Input("b_c: T")
33     .Output("r: T")
34     .Output("u: T")
35     .Output("c: T")
36     .Output("h: T")
__anon4a6db9170102(InferenceContext* c) 37     .SetShapeFn([](InferenceContext* c) {
38       ShapeHandle x, h_prev;
39       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x));
40       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &h_prev));
41 
42       DimensionHandle batch_size = c->Dim(x, 0);
43       DimensionHandle cell_size = c->Dim(h_prev, 1);
44       ShapeHandle output = c->Matrix(batch_size, cell_size);
45       for (int i = 0; i < 4; ++i) {
46         c->set_output(i, output);
47       }
48       return OkStatus();
49     });
50 
51 REGISTER_OP("GRUBlockCellGrad")
52     .Attr("T: {float}")
53     .Input("x: T")
54     .Input("h_prev: T")
55     .Input("w_ru: T")
56     .Input("w_c: T")
57     .Input("b_ru: T")
58     .Input("b_c: T")
59     .Input("r: T")
60     .Input("u: T")
61     .Input("c: T")
62     .Input("d_h: T")
63     .Output("d_x: T")
64     .Output("d_h_prev: T")
65     .Output("d_c_bar: T")
66     .Output("d_r_bar_u_bar: T")
__anon4a6db9170202(InferenceContext* c) 67     .SetShapeFn([](InferenceContext* c) {
68       ShapeHandle x, h_prev, w_ru;
69       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x));
70       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &h_prev));
71       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &w_ru));
72 
73       DimensionHandle batch_size = c->Dim(x, 0);
74       DimensionHandle cell_size = c->Dim(h_prev, 1);
75       DimensionHandle twice_cell_size = c->Dim(w_ru, 1);
76       ShapeHandle batch_cell_shape = c->Matrix(batch_size, cell_size);
77 
78       c->set_output(0, x);
79       c->set_output(1, batch_cell_shape);
80       c->set_output(2, batch_cell_shape);
81       c->set_output(3, c->Matrix(batch_size, twice_cell_size));
82       return OkStatus();
83     });
84 
85 REGISTER_OP("LSTMBlockCell")
86     .Input("x: T")
87     .Input("cs_prev: T")
88     .Input("h_prev: T")
89     .Input("w: T")
90     .Input("wci: T")
91     .Input("wcf: T")
92     .Input("wco: T")
93     .Input("b: T")
94     .Output("i: T")
95     .Output("cs: T")
96     .Output("f: T")
97     .Output("o: T")
98     .Output("ci: T")
99     .Output("co: T")
100     .Output("h: T")
101     .Attr("forget_bias: float = 1.0")
102     .Attr("cell_clip: float = 3.0")
103     .Attr("use_peephole: bool = false")
104     .Attr("T: {half, float}")
__anon4a6db9170302(InferenceContext* c) 105     .SetShapeFn([](InferenceContext* c) {
106       ShapeHandle x, cs_prev;
107       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x));
108       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &cs_prev));
109 
110       DimensionHandle batch_size = c->Dim(x, 0);
111       DimensionHandle cell_size = c->Dim(cs_prev, 1);
112       ShapeHandle output = c->Matrix(batch_size, cell_size);
113       for (int i = 0; i < 7; ++i) {
114         c->set_output(i, output);
115       }
116       return OkStatus();
117     });
118 
119 REGISTER_OP("LSTMBlockCellGrad")
120     .Input("x: T")
121     .Input("cs_prev: T")
122     .Input("h_prev: T")
123     .Input("w: T")
124     .Input("wci: T")
125     .Input("wcf: T")
126     .Input("wco: T")
127     .Input("b: T")
128     .Input("i: T")
129     .Input("cs: T")
130     .Input("f: T")
131     .Input("o: T")
132     .Input("ci: T")
133     .Input("co: T")
134     .Input("cs_grad: T")
135     .Input("h_grad: T")
136     .Output("cs_prev_grad: T")
137     .Output("dicfo: T")
138     .Output("wci_grad: T")
139     .Output("wcf_grad: T")
140     .Output("wco_grad: T")
141     .Attr("use_peephole: bool")
142     .Attr("T: {half, float}")
__anon4a6db9170402(InferenceContext* c) 143     .SetShapeFn([](InferenceContext* c) {
144       ShapeHandle x, cs_prev;
145       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x));
146       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &cs_prev));
147 
148       DimensionHandle batch_size = c->Dim(x, 0);
149       DimensionHandle cell_size = c->Dim(cs_prev, 1);
150       DimensionHandle cell_size_times_4;
151       TF_RETURN_IF_ERROR(c->Multiply(cell_size, 4, &cell_size_times_4));
152       ShapeHandle cell_size_vec = c->Vector(cell_size);
153 
154       c->set_output(0, c->Matrix(batch_size, cell_size));
155       c->set_output(1, c->Matrix(batch_size, cell_size_times_4));
156       c->set_output(2, cell_size_vec);
157       c->set_output(3, cell_size_vec);
158       c->set_output(4, cell_size_vec);
159       return OkStatus();
160     });
161 
162 REGISTER_OP("BlockLSTM")
163     .Input("seq_len_max: int64")
164     .Input("x: T")
165     .Input("cs_prev: T")
166     .Input("h_prev: T")
167     .Input("w: T")
168     .Input("wci: T")
169     .Input("wcf: T")
170     .Input("wco: T")
171     .Input("b: T")
172     .Output("i: T")
173     .Output("cs: T")
174     .Output("f: T")
175     .Output("o: T")
176     .Output("ci: T")
177     .Output("co: T")
178     .Output("h: T")
179     .Attr("forget_bias: float = 1.0")
180     .Attr("cell_clip: float = 3.0")
181     .Attr("use_peephole: bool = false")
182     .Attr("T: {half, float}")
__anon4a6db9170502(InferenceContext* c) 183     .SetShapeFn([](InferenceContext* c) {
184       ShapeHandle x, b;
185       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x));
186       TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &b));
187 
188       DimensionHandle timelen = c->Dim(x, 0);
189       DimensionHandle batch_size = c->Dim(x, 1);
190       DimensionHandle cell_size;
191       TF_RETURN_IF_ERROR(
192           c->Divide(c->Dim(b, 0), 4, true /* evenly_divisible */, &cell_size));
193 
194       DCHECK_EQ(7, c->num_outputs());
195       ShapeHandle output = c->MakeShape({timelen, batch_size, cell_size});
196       for (int i = 0; i < 7; ++i) {
197         c->set_output(i, output);
198       }
199       return OkStatus();
200     });
201 
202 REGISTER_OP("BlockLSTMV2")
203     .Input("seq_len_max: int64")
204     .Input("x: T")
205     .Input("cs_prev: T")
206     .Input("h_prev: T")
207     .Input("w: T")
208     .Input("wci: T")
209     .Input("wcf: T")
210     .Input("wco: T")
211     .Input("b: T")
212     .Output("i: T")
213     .Output("cs: T")
214     .Output("f: T")
215     .Output("o: T")
216     .Output("ci: T")
217     .Output("co: T")
218     .Output("h: T")
219     .Attr("cell_clip: float = 0.0")
220     .Attr("use_peephole: bool = false")
221     .Attr("T: {half, float}")
__anon4a6db9170602(InferenceContext* c) 222     .SetShapeFn([](InferenceContext* c) {
223       ShapeHandle x, b;
224       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x));
225       TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &b));
226 
227       DimensionHandle timelen = c->Dim(x, 0);
228       DimensionHandle batch_size = c->Dim(x, 1);
229       DimensionHandle cell_size;
230       TF_RETURN_IF_ERROR(
231           c->Divide(c->Dim(b, 0), 4, true /* evenly_divisible */, &cell_size));
232 
233       DCHECK_EQ(7, c->num_outputs());
234       ShapeHandle output = c->MakeShape({timelen, batch_size, cell_size});
235       for (int i = 0; i < 7; ++i) {
236         c->set_output(i, output);
237       }
238       return OkStatus();
239     });
240 
241 REGISTER_OP("BlockLSTMGrad")
242     .Input("seq_len_max: int64")
243     .Input("x: T")
244     .Input("cs_prev: T")
245     .Input("h_prev: T")
246     .Input("w: T")
247     .Input("wci: T")
248     .Input("wcf: T")
249     .Input("wco: T")
250     .Input("b: T")
251     .Input("i: T")
252     .Input("cs: T")
253     .Input("f: T")
254     .Input("o: T")
255     .Input("ci: T")
256     .Input("co: T")
257     .Input("h: T")
258     .Input("cs_grad: T")
259     .Input("h_grad: T")
260     .Output("x_grad: T")
261     .Output("cs_prev_grad: T")
262     .Output("h_prev_grad: T")
263     .Output("w_grad: T")
264     .Output("wci_grad: T")
265     .Output("wcf_grad: T")
266     .Output("wco_grad: T")
267     .Output("b_grad: T")
268     .Attr("use_peephole: bool")
269     .Attr("T: {half, float}")
__anon4a6db9170702(InferenceContext* c) 270     .SetShapeFn([](InferenceContext* c) {
271       ShapeHandle x, cs_prev, h_prev, w, wci, wco, wcf, b;
272       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x));
273       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &cs_prev));
274       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &h_prev));
275       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &w));
276       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &wci));
277       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &wco));
278       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &wcf));
279       TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 1, &b));
280 
281       c->set_output(0, x);
282       c->set_output(1, cs_prev);
283       c->set_output(2, h_prev);
284       c->set_output(3, w);
285       c->set_output(4, wci);
286       c->set_output(5, wco);
287       c->set_output(6, wcf);
288       c->set_output(7, b);
289 
290       return OkStatus();
291     });
292 
293 REGISTER_OP("BlockLSTMGradV2")
294     .Input("seq_len_max: int64")
295     .Input("x: T")
296     .Input("cs_prev: T")
297     .Input("h_prev: T")
298     .Input("w: T")
299     .Input("wci: T")
300     .Input("wcf: T")
301     .Input("wco: T")
302     .Input("b: T")
303     .Input("i: T")
304     .Input("cs: T")
305     .Input("f: T")
306     .Input("o: T")
307     .Input("ci: T")
308     .Input("co: T")
309     .Input("h: T")
310     .Input("cs_grad: T")
311     .Input("h_grad: T")
312     .Output("x_grad: T")
313     .Output("cs_prev_grad: T")
314     .Output("h_prev_grad: T")
315     .Output("w_grad: T")
316     .Output("wci_grad: T")
317     .Output("wcf_grad: T")
318     .Output("wco_grad: T")
319     .Output("b_grad: T")
320     .Attr("use_peephole: bool")
321     .Attr("T: {half, float}")
__anon4a6db9170802(InferenceContext* c) 322     .SetShapeFn([](InferenceContext* c) {
323       ShapeHandle x, cs_prev, h_prev, w, wci, wco, wcf, b;
324       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x));
325       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &cs_prev));
326       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &h_prev));
327       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &w));
328       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &wci));
329       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &wco));
330       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &wcf));
331       TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 1, &b));
332 
333       c->set_output(0, x);
334       c->set_output(1, cs_prev);
335       c->set_output(2, h_prev);
336       c->set_output(3, w);
337       c->set_output(4, wci);
338       c->set_output(5, wco);
339       c->set_output(6, wcf);
340       c->set_output(7, b);
341 
342       return OkStatus();
343     });
344 
345 }  // end namespace tensorflow
346