1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/shape_inference.h" 19 #include "tensorflow/core/lib/strings/strcat.h" 20 21 namespace tensorflow { 22 namespace { 23 24 constexpr auto kRNNModeAttrs = 25 "rnn_mode: {'rnn_relu', 'rnn_tanh', 'lstm', 'gru'} = 'lstm'"; 26 27 constexpr auto kRNNInputModeAttrs = 28 "input_mode: {'linear_input', 'skip_input', 'auto_select'} = " 29 "'linear_input'"; 30 31 constexpr auto kRNNDirectionAttrs = 32 "direction: {'unidirectional', 'bidirectional'} = 'unidirectional'"; 33 34 } // namespace 35 36 using shape_inference::DimensionHandle; 37 using shape_inference::InferenceContext; 38 using shape_inference::ShapeHandle; 39 40 REGISTER_OP("CudnnRNNParamsSize") 41 .Input("num_layers: int32") 42 .Input("num_units: int32") 43 .Input("input_size: int32") 44 .Attr("T: {float16, float32, float64}") 45 .Attr("S: {int32, int64}") 46 .Attr(kRNNModeAttrs) 47 .Attr(kRNNInputModeAttrs) 48 .Attr(kRNNDirectionAttrs) 49 .Attr("dropout: float = 0.0") 50 .Attr("seed: int = 0") 51 .Attr("seed2: int = 0") 52 .Attr("num_proj: int = 0") 53 .Output("params_size: S") __anon815c5c9f0202(InferenceContext* c) 54 .SetShapeFn([](InferenceContext* c) { 55 ShapeHandle unused; 56 // num_layers, num_units, and input_size should be scalars. 57 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); 58 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 59 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 60 61 c->set_output(0, c->Vector(1)); 62 return Status::OK(); 63 }); 64 65 REGISTER_OP("CudnnRNN") 66 .Input("input: T") 67 .Input("input_h: T") 68 .Input("input_c: T") 69 .Input("params: T") 70 .SetIsStateful() 71 .Output("output: T") 72 .Output("output_h: T") 73 .Output("output_c: T") 74 .Output("reserve_space: T") 75 .Attr("T: {float16, float32, float64}") 76 .Attr(kRNNModeAttrs) 77 .Attr(kRNNInputModeAttrs) 78 .Attr(kRNNDirectionAttrs) 79 .Attr("dropout: float = 0.0") 80 .Attr("seed: int = 0") 81 .Attr("seed2: int = 0") 82 .Attr("is_training: bool = true") __anon815c5c9f0302(InferenceContext* c) 83 .SetShapeFn([](InferenceContext* c) { 84 auto input_shape = c->input(0); 85 auto input_h_shape = c->input(1); 86 auto seq_length = c->Dim(input_shape, 0); 87 auto batch_size = c->Dim(input_shape, 1); 88 auto num_units = c->Dim(input_h_shape, 2); 89 string direction; 90 TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction)); 91 string rnn_mode; 92 TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode)); 93 int dir_count = (direction == "bidirectional") ? 2 : 1; 94 DimensionHandle output_size; 95 TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size)); 96 auto output_shape = c->MakeShape({seq_length, batch_size, output_size}); 97 auto output_h_shape = input_h_shape; 98 auto output_c_shape TF_ATTRIBUTE_UNUSED = 99 (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({}); 100 c->set_output(0, output_shape); 101 c->set_output(1, output_h_shape); 102 c->set_output(2, output_c_shape); 103 c->set_output(3, c->UnknownShape()); 104 return Status::OK(); 105 }); 106 107 REGISTER_OP("CudnnRNNV2") 108 .Input("input: T") 109 .Input("input_h: T") 110 .Input("input_c: T") 111 .Input("params: T") 112 .SetIsStateful() 113 .Output("output: T") 114 .Output("output_h: T") 115 .Output("output_c: T") 116 .Output("reserve_space: T") 117 .Output("host_reserved: int8") 118 .Attr("T: {float16, float32, float64}") 119 .Attr(kRNNModeAttrs) 120 .Attr(kRNNInputModeAttrs) 121 .Attr(kRNNDirectionAttrs) 122 .Attr("dropout: float = 0.0") 123 .Attr("seed: int = 0") 124 .Attr("seed2: int = 0") 125 .Attr("is_training: bool = true") __anon815c5c9f0402(InferenceContext* c) 126 .SetShapeFn([](InferenceContext* c) { 127 auto input_shape = c->input(0); 128 auto input_h_shape = c->input(1); 129 auto seq_length = c->Dim(input_shape, 0); 130 auto batch_size = c->Dim(input_shape, 1); 131 auto num_units = c->Dim(input_h_shape, 2); 132 string direction; 133 TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction)); 134 string rnn_mode; 135 TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode)); 136 int dir_count = (direction == "bidirectional") ? 2 : 1; 137 DimensionHandle output_size; 138 TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size)); 139 auto output_shape = c->MakeShape({seq_length, batch_size, output_size}); 140 auto output_h_shape = input_h_shape; 141 auto output_c_shape TF_ATTRIBUTE_UNUSED = 142 (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({}); 143 c->set_output(0, output_shape); 144 c->set_output(1, output_h_shape); 145 c->set_output(2, output_c_shape); 146 c->set_output(3, c->UnknownShape()); 147 c->set_output(4, c->UnknownShape()); 148 return Status::OK(); 149 }); 150 151 REGISTER_OP("CudnnRNNV3") 152 .Input("input: T") 153 .Input("input_h: T") 154 .Input("input_c: T") 155 .Input("params: T") 156 .Input("sequence_lengths: int32") 157 .SetIsStateful() 158 .Output("output: T") 159 .Output("output_h: T") 160 .Output("output_c: T") 161 .Output("reserve_space: T") 162 .Output("host_reserved: int8") 163 .Attr("T: {float16, float32, float64}") 164 .Attr(kRNNModeAttrs) 165 .Attr(kRNNInputModeAttrs) 166 .Attr(kRNNDirectionAttrs) 167 .Attr("dropout: float = 0.0") 168 .Attr("seed: int = 0") 169 .Attr("seed2: int = 0") 170 .Attr("num_proj: int = 0") 171 .Attr("is_training: bool = true") 172 .Attr("time_major: bool = true") __anon815c5c9f0502(InferenceContext* c) 173 .SetShapeFn([](InferenceContext* c) { 174 auto input_shape = c->input(0); 175 auto input_h_shape = c->input(1); 176 auto input_c_shape = c->input(2); 177 auto max_seq_length = c->Dim(input_shape, 0); 178 auto batch_size = c->Dim(input_shape, 1); 179 auto num_units = c->Dim(input_h_shape, 2); 180 string direction; 181 TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction)); 182 string rnn_mode; 183 TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode)); 184 int dir_count = (direction == "bidirectional") ? 2 : 1; 185 DimensionHandle output_size; 186 TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size)); 187 auto output_shape = 188 c->MakeShape({max_seq_length, batch_size, output_size}); 189 auto output_h_shape = input_h_shape; 190 auto output_c_shape TF_ATTRIBUTE_UNUSED = 191 (rnn_mode == "lstm") ? input_c_shape : c->MakeShape({}); 192 c->set_output(0, output_shape); 193 c->set_output(1, output_h_shape); 194 c->set_output(2, output_c_shape); 195 c->set_output(3, c->UnknownShape()); 196 c->set_output(4, c->UnknownShape()); 197 return Status::OK(); 198 }); 199 200 REGISTER_OP("CudnnRNNBackprop") 201 .Input("input: T") 202 .Input("input_h: T") 203 .Input("input_c: T") 204 .Input("params: T") 205 .Input("output: T") 206 .Input("output_h: T") 207 .Input("output_c: T") 208 .Input("output_backprop: T") 209 .Input("output_h_backprop: T") 210 .Input("output_c_backprop: T") 211 .Input("reserve_space: T") 212 .SetIsStateful() 213 .Output("input_backprop: T") 214 .Output("input_h_backprop: T") 215 .Output("input_c_backprop: T") 216 .Output("params_backprop: T") 217 .Attr("T: {float16, float32, float64}") 218 .Attr(kRNNModeAttrs) 219 .Attr(kRNNInputModeAttrs) 220 .Attr(kRNNDirectionAttrs) 221 .Attr("dropout: float = 0.0") 222 .Attr("seed: int = 0") 223 .Attr("seed2: int = 0") __anon815c5c9f0602(InferenceContext* c) 224 .SetShapeFn([](InferenceContext* c) { 225 auto input_shape = c->input(0); 226 auto input_h_shape = c->input(1); 227 auto input_c_shape = c->input(2); 228 auto params_shape = c->input(3); 229 c->set_output(0, input_shape); 230 c->set_output(1, input_h_shape); 231 c->set_output(2, input_c_shape); 232 c->set_output(3, params_shape); 233 return Status::OK(); 234 }); 235 236 REGISTER_OP("CudnnRNNBackpropV2") 237 .Input("input: T") 238 .Input("input_h: T") 239 .Input("input_c: T") 240 .Input("params: T") 241 .Input("output: T") 242 .Input("output_h: T") 243 .Input("output_c: T") 244 .Input("output_backprop: T") 245 .Input("output_h_backprop: T") 246 .Input("output_c_backprop: T") 247 .Input("reserve_space: T") 248 .Input("host_reserved: int8") 249 .SetIsStateful() 250 .Output("input_backprop: T") 251 .Output("input_h_backprop: T") 252 .Output("input_c_backprop: T") 253 .Output("params_backprop: T") 254 .Attr("T: {float16, float32, float64}") 255 .Attr(kRNNModeAttrs) 256 .Attr(kRNNInputModeAttrs) 257 .Attr(kRNNDirectionAttrs) 258 .Attr("dropout: float = 0.0") 259 .Attr("seed: int = 0") 260 .Attr("seed2: int = 0") __anon815c5c9f0702(InferenceContext* c) 261 .SetShapeFn([](InferenceContext* c) { 262 auto input_shape = c->input(0); 263 auto input_h_shape = c->input(1); 264 auto input_c_shape = c->input(2); 265 auto params_shape = c->input(3); 266 c->set_output(0, input_shape); 267 c->set_output(1, input_h_shape); 268 c->set_output(2, input_c_shape); 269 c->set_output(3, params_shape); 270 return Status::OK(); 271 }); 272 273 REGISTER_OP("CudnnRNNBackpropV3") 274 .Input("input: T") 275 .Input("input_h: T") 276 .Input("input_c: T") 277 .Input("params: T") 278 .Input("sequence_lengths: int32") 279 .Input("output: T") 280 .Input("output_h: T") 281 .Input("output_c: T") 282 .Input("output_backprop: T") 283 .Input("output_h_backprop: T") 284 .Input("output_c_backprop: T") 285 .Input("reserve_space: T") 286 .Input("host_reserved: int8") 287 .SetIsStateful() 288 .Output("input_backprop: T") 289 .Output("input_h_backprop: T") 290 .Output("input_c_backprop: T") 291 .Output("params_backprop: T") 292 .Attr("T: {float16, float32, float64}") 293 .Attr(kRNNModeAttrs) 294 .Attr(kRNNInputModeAttrs) 295 .Attr(kRNNDirectionAttrs) 296 .Attr("dropout: float = 0.0") 297 .Attr("seed: int = 0") 298 .Attr("seed2: int = 0") 299 .Attr("num_proj: int = 0") 300 .Attr("time_major: bool = true") __anon815c5c9f0802(InferenceContext* c) 301 .SetShapeFn([](InferenceContext* c) { 302 auto input_shape = c->input(0); 303 auto input_h_shape = c->input(1); 304 auto input_c_shape = c->input(2); 305 auto params_shape = c->input(3); 306 c->set_output(0, input_shape); 307 c->set_output(1, input_h_shape); 308 c->set_output(2, input_c_shape); 309 c->set_output(3, params_shape); 310 return Status::OK(); 311 }); 312 313 REGISTER_OP("CudnnRNNParamsToCanonical") 314 .Input("num_layers: int32") 315 .Input("num_units: int32") 316 .Input("input_size: int32") 317 .Input("params: T") 318 .Output("weights: num_params * T") 319 .Output("biases: num_params * T") 320 .Attr("T: {float16, float32, float64}") 321 .Attr("num_params: int") 322 .Attr(kRNNModeAttrs) 323 .Attr(kRNNInputModeAttrs) 324 .Attr(kRNNDirectionAttrs) 325 .Attr("dropout: float = 0.0") 326 .Attr("seed: int = 0") 327 .Attr("seed2: int = 0") __anon815c5c9f0902(InferenceContext* c) 328 .SetShapeFn([](InferenceContext* c) { 329 ShapeHandle unused; 330 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused)); 331 int num_params; 332 TF_RETURN_IF_ERROR(c->GetAttr("num_params", &num_params)); 333 // Set shape for weight matrices 334 for (int i = 0; i < num_params; i++) { 335 c->set_output(i, c->Matrix(InferenceContext::kUnknownDim, 336 InferenceContext::kUnknownDim)); 337 } 338 // Set shape for bias vectors 339 for (int i = 0; i < num_params; i++) { 340 c->set_output(num_params + i, c->Vector(InferenceContext::kUnknownDim)); 341 } 342 return Status::OK(); 343 }); 344 345 REGISTER_OP("CudnnRNNParamsToCanonicalV2") 346 .Input("num_layers: int32") 347 .Input("num_units: int32") 348 .Input("input_size: int32") 349 .Input("params: T") 350 .Output("weights: num_params_weights * T") 351 .Output("biases: num_params_biases * T") 352 .Attr("T: {float16, float32, float64}") 353 .Attr("num_params_weights: int") 354 .Attr("num_params_biases: int") 355 .Attr(kRNNModeAttrs) 356 .Attr(kRNNInputModeAttrs) 357 .Attr(kRNNDirectionAttrs) 358 .Attr("dropout: float = 0.0") 359 .Attr("seed: int = 0") 360 .Attr("seed2: int = 0") 361 .Attr("num_proj: int = 0") __anon815c5c9f0a02(InferenceContext* c) 362 .SetShapeFn([](InferenceContext* c) { 363 ShapeHandle unused; 364 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused)); 365 int num_params_weights; 366 int num_params_biases; 367 TF_RETURN_IF_ERROR(c->GetAttr("num_params_weights", &num_params_weights)); 368 TF_RETURN_IF_ERROR(c->GetAttr("num_params_biases", &num_params_biases)); 369 // Set shape for weight matrices 370 for (int i = 0; i < num_params_weights; i++) { 371 c->set_output(i, c->Matrix(InferenceContext::kUnknownDim, 372 InferenceContext::kUnknownDim)); 373 } 374 // Set shape for bias vectors 375 for (int i = 0; i < num_params_biases; i++) { 376 c->set_output(num_params_weights + i, 377 c->Vector(InferenceContext::kUnknownDim)); 378 } 379 return Status::OK(); 380 }); 381 382 REGISTER_OP("CudnnRNNCanonicalToParams") 383 .Input("num_layers: int32") 384 .Input("num_units: int32") 385 .Input("input_size: int32") 386 .Input("weights: num_params * T") 387 .Input("biases: num_params * T") 388 .Output("params: T") 389 .Attr("T: {float16, float32, float64}") 390 .Attr("num_params: int") 391 .Attr(kRNNModeAttrs) 392 .Attr(kRNNInputModeAttrs) 393 .Attr(kRNNDirectionAttrs) 394 .Attr("dropout: float = 0.0") 395 .Attr("seed: int = 0") 396 .Attr("seed2: int = 0") __anon815c5c9f0b02(InferenceContext* c) 397 .SetShapeFn([](InferenceContext* c) { 398 c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); 399 return Status::OK(); 400 }); 401 402 REGISTER_OP("CudnnRNNCanonicalToParamsV2") 403 .Input("num_layers: int32") 404 .Input("num_units: int32") 405 .Input("input_size: int32") 406 .Input("weights: num_params_weights * T") 407 .Input("biases: num_params_biases * T") 408 .Output("params: T") 409 .Attr("T: {float16, float32, float64}") 410 .Attr("num_params_weights: int") 411 .Attr("num_params_biases: int") 412 .Attr(kRNNModeAttrs) 413 .Attr(kRNNInputModeAttrs) 414 .Attr(kRNNDirectionAttrs) 415 .Attr("dropout: float = 0.0") 416 .Attr("seed: int = 0") 417 .Attr("seed2: int = 0") 418 .Attr("num_proj: int = 0") __anon815c5c9f0c02(InferenceContext* c) 419 .SetShapeFn([](InferenceContext* c) { 420 c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); 421 return Status::OK(); 422 }); 423 424 } // namespace tensorflow 425