• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 #include "tensorflow/core/lib/strings/strcat.h"
20 
21 namespace tensorflow {
namespace {

// Attr spec shared by all Cudnn RNN ops: which RNN cell type to run.
constexpr char kRNNModeAttrs[] =
    "rnn_mode: {'rnn_relu', 'rnn_tanh', 'lstm', 'gru'} = 'lstm'";

// Attr spec shared by all Cudnn RNN ops: how the input layer is handled
// (dense projection, passthrough, or chosen automatically).
constexpr char kRNNInputModeAttrs[] =
    "input_mode: {'linear_input', 'skip_input', 'auto_select'} = "
    "'linear_input'";

// Attr spec shared by all Cudnn RNN ops: uni- vs. bi-directional runs.
constexpr char kRNNDirectionAttrs[] =
    "direction: {'unidirectional', 'bidirectional'} = 'unidirectional'";

}  // namespace
35 
36 using shape_inference::DimensionHandle;
37 using shape_inference::InferenceContext;
38 using shape_inference::ShapeHandle;
39 
40 REGISTER_OP("CudnnRNNParamsSize")
41     .Input("num_layers: int32")
42     .Input("num_units: int32")
43     .Input("input_size: int32")
44     .Attr("T: {float16, float32, float64}")
45     .Attr("S: {int32, int64}")
46     .Attr(kRNNModeAttrs)
47     .Attr(kRNNInputModeAttrs)
48     .Attr(kRNNDirectionAttrs)
49     .Attr("dropout: float = 0.0")
50     .Attr("seed: int = 0")
51     .Attr("seed2: int = 0")
52     .Attr("num_proj: int = 0")
53     .Output("params_size: S")
__anon815c5c9f0202(InferenceContext* c) 54     .SetShapeFn([](InferenceContext* c) {
55       ShapeHandle unused;
56       // num_layers, num_units, and input_size should be scalars.
57       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
58       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
59       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
60 
61       c->set_output(0, c->Vector(1));
62       return Status::OK();
63     });
64 
65 REGISTER_OP("CudnnRNN")
66     .Input("input: T")
67     .Input("input_h: T")
68     .Input("input_c: T")
69     .Input("params: T")
70     .SetIsStateful()
71     .Output("output: T")
72     .Output("output_h: T")
73     .Output("output_c: T")
74     .Output("reserve_space: T")
75     .Attr("T: {float16, float32, float64}")
76     .Attr(kRNNModeAttrs)
77     .Attr(kRNNInputModeAttrs)
78     .Attr(kRNNDirectionAttrs)
79     .Attr("dropout: float = 0.0")
80     .Attr("seed: int = 0")
81     .Attr("seed2: int = 0")
82     .Attr("is_training: bool = true")
__anon815c5c9f0302(InferenceContext* c) 83     .SetShapeFn([](InferenceContext* c) {
84       auto input_shape = c->input(0);
85       auto input_h_shape = c->input(1);
86       auto seq_length = c->Dim(input_shape, 0);
87       auto batch_size = c->Dim(input_shape, 1);
88       auto num_units = c->Dim(input_h_shape, 2);
89       string direction;
90       TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
91       string rnn_mode;
92       TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode));
93       int dir_count = (direction == "bidirectional") ? 2 : 1;
94       DimensionHandle output_size;
95       TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size));
96       auto output_shape = c->MakeShape({seq_length, batch_size, output_size});
97       auto output_h_shape = input_h_shape;
98       auto output_c_shape TF_ATTRIBUTE_UNUSED =
99           (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({});
100       c->set_output(0, output_shape);
101       c->set_output(1, output_h_shape);
102       c->set_output(2, output_c_shape);
103       c->set_output(3, c->UnknownShape());
104       return Status::OK();
105     });
106 
107 REGISTER_OP("CudnnRNNV2")
108     .Input("input: T")
109     .Input("input_h: T")
110     .Input("input_c: T")
111     .Input("params: T")
112     .SetIsStateful()
113     .Output("output: T")
114     .Output("output_h: T")
115     .Output("output_c: T")
116     .Output("reserve_space: T")
117     .Output("host_reserved: int8")
118     .Attr("T: {float16, float32, float64}")
119     .Attr(kRNNModeAttrs)
120     .Attr(kRNNInputModeAttrs)
121     .Attr(kRNNDirectionAttrs)
122     .Attr("dropout: float = 0.0")
123     .Attr("seed: int = 0")
124     .Attr("seed2: int = 0")
125     .Attr("is_training: bool = true")
__anon815c5c9f0402(InferenceContext* c) 126     .SetShapeFn([](InferenceContext* c) {
127       auto input_shape = c->input(0);
128       auto input_h_shape = c->input(1);
129       auto seq_length = c->Dim(input_shape, 0);
130       auto batch_size = c->Dim(input_shape, 1);
131       auto num_units = c->Dim(input_h_shape, 2);
132       string direction;
133       TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
134       string rnn_mode;
135       TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode));
136       int dir_count = (direction == "bidirectional") ? 2 : 1;
137       DimensionHandle output_size;
138       TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size));
139       auto output_shape = c->MakeShape({seq_length, batch_size, output_size});
140       auto output_h_shape = input_h_shape;
141       auto output_c_shape TF_ATTRIBUTE_UNUSED =
142           (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({});
143       c->set_output(0, output_shape);
144       c->set_output(1, output_h_shape);
145       c->set_output(2, output_c_shape);
146       c->set_output(3, c->UnknownShape());
147       c->set_output(4, c->UnknownShape());
148       return Status::OK();
149     });
150 
151 REGISTER_OP("CudnnRNNV3")
152     .Input("input: T")
153     .Input("input_h: T")
154     .Input("input_c: T")
155     .Input("params: T")
156     .Input("sequence_lengths: int32")
157     .SetIsStateful()
158     .Output("output: T")
159     .Output("output_h: T")
160     .Output("output_c: T")
161     .Output("reserve_space: T")
162     .Output("host_reserved: int8")
163     .Attr("T: {float16, float32, float64}")
164     .Attr(kRNNModeAttrs)
165     .Attr(kRNNInputModeAttrs)
166     .Attr(kRNNDirectionAttrs)
167     .Attr("dropout: float = 0.0")
168     .Attr("seed: int = 0")
169     .Attr("seed2: int = 0")
170     .Attr("num_proj: int = 0")
171     .Attr("is_training: bool = true")
172     .Attr("time_major: bool = true")
__anon815c5c9f0502(InferenceContext* c) 173     .SetShapeFn([](InferenceContext* c) {
174       auto input_shape = c->input(0);
175       auto input_h_shape = c->input(1);
176       auto input_c_shape = c->input(2);
177       auto max_seq_length = c->Dim(input_shape, 0);
178       auto batch_size = c->Dim(input_shape, 1);
179       auto num_units = c->Dim(input_h_shape, 2);
180       string direction;
181       TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
182       string rnn_mode;
183       TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode));
184       int dir_count = (direction == "bidirectional") ? 2 : 1;
185       DimensionHandle output_size;
186       TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size));
187       auto output_shape =
188           c->MakeShape({max_seq_length, batch_size, output_size});
189       auto output_h_shape = input_h_shape;
190       auto output_c_shape TF_ATTRIBUTE_UNUSED =
191           (rnn_mode == "lstm") ? input_c_shape : c->MakeShape({});
192       c->set_output(0, output_shape);
193       c->set_output(1, output_h_shape);
194       c->set_output(2, output_c_shape);
195       c->set_output(3, c->UnknownShape());
196       c->set_output(4, c->UnknownShape());
197       return Status::OK();
198     });
199 
200 REGISTER_OP("CudnnRNNBackprop")
201     .Input("input: T")
202     .Input("input_h: T")
203     .Input("input_c: T")
204     .Input("params: T")
205     .Input("output: T")
206     .Input("output_h: T")
207     .Input("output_c: T")
208     .Input("output_backprop: T")
209     .Input("output_h_backprop: T")
210     .Input("output_c_backprop: T")
211     .Input("reserve_space: T")
212     .SetIsStateful()
213     .Output("input_backprop: T")
214     .Output("input_h_backprop: T")
215     .Output("input_c_backprop: T")
216     .Output("params_backprop: T")
217     .Attr("T: {float16, float32, float64}")
218     .Attr(kRNNModeAttrs)
219     .Attr(kRNNInputModeAttrs)
220     .Attr(kRNNDirectionAttrs)
221     .Attr("dropout: float = 0.0")
222     .Attr("seed: int = 0")
223     .Attr("seed2: int = 0")
__anon815c5c9f0602(InferenceContext* c) 224     .SetShapeFn([](InferenceContext* c) {
225       auto input_shape = c->input(0);
226       auto input_h_shape = c->input(1);
227       auto input_c_shape = c->input(2);
228       auto params_shape = c->input(3);
229       c->set_output(0, input_shape);
230       c->set_output(1, input_h_shape);
231       c->set_output(2, input_c_shape);
232       c->set_output(3, params_shape);
233       return Status::OK();
234     });
235 
236 REGISTER_OP("CudnnRNNBackpropV2")
237     .Input("input: T")
238     .Input("input_h: T")
239     .Input("input_c: T")
240     .Input("params: T")
241     .Input("output: T")
242     .Input("output_h: T")
243     .Input("output_c: T")
244     .Input("output_backprop: T")
245     .Input("output_h_backprop: T")
246     .Input("output_c_backprop: T")
247     .Input("reserve_space: T")
248     .Input("host_reserved: int8")
249     .SetIsStateful()
250     .Output("input_backprop: T")
251     .Output("input_h_backprop: T")
252     .Output("input_c_backprop: T")
253     .Output("params_backprop: T")
254     .Attr("T: {float16, float32, float64}")
255     .Attr(kRNNModeAttrs)
256     .Attr(kRNNInputModeAttrs)
257     .Attr(kRNNDirectionAttrs)
258     .Attr("dropout: float = 0.0")
259     .Attr("seed: int = 0")
260     .Attr("seed2: int = 0")
__anon815c5c9f0702(InferenceContext* c) 261     .SetShapeFn([](InferenceContext* c) {
262       auto input_shape = c->input(0);
263       auto input_h_shape = c->input(1);
264       auto input_c_shape = c->input(2);
265       auto params_shape = c->input(3);
266       c->set_output(0, input_shape);
267       c->set_output(1, input_h_shape);
268       c->set_output(2, input_c_shape);
269       c->set_output(3, params_shape);
270       return Status::OK();
271     });
272 
273 REGISTER_OP("CudnnRNNBackpropV3")
274     .Input("input: T")
275     .Input("input_h: T")
276     .Input("input_c: T")
277     .Input("params: T")
278     .Input("sequence_lengths: int32")
279     .Input("output: T")
280     .Input("output_h: T")
281     .Input("output_c: T")
282     .Input("output_backprop: T")
283     .Input("output_h_backprop: T")
284     .Input("output_c_backprop: T")
285     .Input("reserve_space: T")
286     .Input("host_reserved: int8")
287     .SetIsStateful()
288     .Output("input_backprop: T")
289     .Output("input_h_backprop: T")
290     .Output("input_c_backprop: T")
291     .Output("params_backprop: T")
292     .Attr("T: {float16, float32, float64}")
293     .Attr(kRNNModeAttrs)
294     .Attr(kRNNInputModeAttrs)
295     .Attr(kRNNDirectionAttrs)
296     .Attr("dropout: float = 0.0")
297     .Attr("seed: int = 0")
298     .Attr("seed2: int = 0")
299     .Attr("num_proj: int = 0")
300     .Attr("time_major: bool = true")
__anon815c5c9f0802(InferenceContext* c) 301     .SetShapeFn([](InferenceContext* c) {
302       auto input_shape = c->input(0);
303       auto input_h_shape = c->input(1);
304       auto input_c_shape = c->input(2);
305       auto params_shape = c->input(3);
306       c->set_output(0, input_shape);
307       c->set_output(1, input_h_shape);
308       c->set_output(2, input_c_shape);
309       c->set_output(3, params_shape);
310       return Status::OK();
311     });
312 
313 REGISTER_OP("CudnnRNNParamsToCanonical")
314     .Input("num_layers: int32")
315     .Input("num_units: int32")
316     .Input("input_size: int32")
317     .Input("params: T")
318     .Output("weights: num_params * T")
319     .Output("biases: num_params * T")
320     .Attr("T: {float16, float32, float64}")
321     .Attr("num_params: int")
322     .Attr(kRNNModeAttrs)
323     .Attr(kRNNInputModeAttrs)
324     .Attr(kRNNDirectionAttrs)
325     .Attr("dropout: float = 0.0")
326     .Attr("seed: int = 0")
327     .Attr("seed2: int = 0")
__anon815c5c9f0902(InferenceContext* c) 328     .SetShapeFn([](InferenceContext* c) {
329       ShapeHandle unused;
330       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
331       int num_params;
332       TF_RETURN_IF_ERROR(c->GetAttr("num_params", &num_params));
333       // Set shape for weight matrices
334       for (int i = 0; i < num_params; i++) {
335         c->set_output(i, c->Matrix(InferenceContext::kUnknownDim,
336                                    InferenceContext::kUnknownDim));
337       }
338       // Set shape for bias vectors
339       for (int i = 0; i < num_params; i++) {
340         c->set_output(num_params + i, c->Vector(InferenceContext::kUnknownDim));
341       }
342       return Status::OK();
343     });
344 
345 REGISTER_OP("CudnnRNNParamsToCanonicalV2")
346     .Input("num_layers: int32")
347     .Input("num_units: int32")
348     .Input("input_size: int32")
349     .Input("params: T")
350     .Output("weights: num_params_weights * T")
351     .Output("biases: num_params_biases * T")
352     .Attr("T: {float16, float32, float64}")
353     .Attr("num_params_weights: int")
354     .Attr("num_params_biases: int")
355     .Attr(kRNNModeAttrs)
356     .Attr(kRNNInputModeAttrs)
357     .Attr(kRNNDirectionAttrs)
358     .Attr("dropout: float = 0.0")
359     .Attr("seed: int = 0")
360     .Attr("seed2: int = 0")
361     .Attr("num_proj: int = 0")
__anon815c5c9f0a02(InferenceContext* c) 362     .SetShapeFn([](InferenceContext* c) {
363       ShapeHandle unused;
364       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
365       int num_params_weights;
366       int num_params_biases;
367       TF_RETURN_IF_ERROR(c->GetAttr("num_params_weights", &num_params_weights));
368       TF_RETURN_IF_ERROR(c->GetAttr("num_params_biases", &num_params_biases));
369       // Set shape for weight matrices
370       for (int i = 0; i < num_params_weights; i++) {
371         c->set_output(i, c->Matrix(InferenceContext::kUnknownDim,
372                                    InferenceContext::kUnknownDim));
373       }
374       // Set shape for bias vectors
375       for (int i = 0; i < num_params_biases; i++) {
376         c->set_output(num_params_weights + i,
377                       c->Vector(InferenceContext::kUnknownDim));
378       }
379       return Status::OK();
380     });
381 
382 REGISTER_OP("CudnnRNNCanonicalToParams")
383     .Input("num_layers: int32")
384     .Input("num_units: int32")
385     .Input("input_size: int32")
386     .Input("weights: num_params * T")
387     .Input("biases: num_params * T")
388     .Output("params: T")
389     .Attr("T: {float16, float32, float64}")
390     .Attr("num_params: int")
391     .Attr(kRNNModeAttrs)
392     .Attr(kRNNInputModeAttrs)
393     .Attr(kRNNDirectionAttrs)
394     .Attr("dropout: float = 0.0")
395     .Attr("seed: int = 0")
396     .Attr("seed2: int = 0")
__anon815c5c9f0b02(InferenceContext* c) 397     .SetShapeFn([](InferenceContext* c) {
398       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
399       return Status::OK();
400     });
401 
402 REGISTER_OP("CudnnRNNCanonicalToParamsV2")
403     .Input("num_layers: int32")
404     .Input("num_units: int32")
405     .Input("input_size: int32")
406     .Input("weights: num_params_weights * T")
407     .Input("biases: num_params_biases * T")
408     .Output("params: T")
409     .Attr("T: {float16, float32, float64}")
410     .Attr("num_params_weights: int")
411     .Attr("num_params_biases: int")
412     .Attr(kRNNModeAttrs)
413     .Attr(kRNNInputModeAttrs)
414     .Attr(kRNNDirectionAttrs)
415     .Attr("dropout: float = 0.0")
416     .Attr("seed: int = 0")
417     .Attr("seed2: int = 0")
418     .Attr("num_proj: int = 0")
__anon815c5c9f0c02(InferenceContext* c) 419     .SetShapeFn([](InferenceContext* c) {
420       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
421       return Status::OK();
422     });
423 
424 }  // namespace tensorflow
425