1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/op.h" 17 #include "tensorflow/core/framework/shape_inference.h" 18 19 namespace tensorflow { 20 21 using shape_inference::DimensionHandle; 22 using shape_inference::InferenceContext; 23 using shape_inference::ShapeHandle; 24 25 REGISTER_OP("GRUBlockCell") 26 .Attr("T: {float}") 27 .Input("x: T") 28 .Input("h_prev: T") 29 .Input("w_ru: T") 30 .Input("w_c: T") 31 .Input("b_ru: T") 32 .Input("b_c: T") 33 .Output("r: T") 34 .Output("u: T") 35 .Output("c: T") 36 .Output("h: T") __anon652ace880102(InferenceContext* c) 37 .SetShapeFn([](InferenceContext* c) { 38 ShapeHandle x, h_prev; 39 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x)); 40 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &h_prev)); 41 42 DimensionHandle batch_size = c->Dim(x, 0); 43 DimensionHandle cell_size = c->Dim(h_prev, 1); 44 ShapeHandle output = c->Matrix(batch_size, cell_size); 45 for (int i = 0; i < 4; ++i) { 46 c->set_output(i, output); 47 } 48 return tensorflow::Status::OK(); 49 }); 50 51 REGISTER_OP("GRUBlockCellGrad") 52 .Attr("T: {float}") 53 .Input("x: T") 54 .Input("h_prev: T") 55 .Input("w_ru: T") 56 .Input("w_c: T") 57 .Input("b_ru: T") 58 .Input("b_c: T") 59 .Input("r: T") 60 .Input("u: T") 61 .Input("c: T") 62 .Input("d_h: T") 63 .Output("d_x: T") 64 .Output("d_h_prev: T") 65 .Output("d_c_bar: T") 66 .Output("d_r_bar_u_bar: T") __anon652ace880202(InferenceContext* c) 67 .SetShapeFn([](InferenceContext* c) { 68 ShapeHandle x, h_prev, w_ru; 69 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x)); 70 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &h_prev)); 71 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &w_ru)); 72 73 DimensionHandle batch_size = c->Dim(x, 0); 74 DimensionHandle cell_size = c->Dim(h_prev, 1); 75 DimensionHandle twice_cell_size = c->Dim(w_ru, 1); 76 ShapeHandle batch_cell_shape = c->Matrix(batch_size, cell_size); 77 78 c->set_output(0, x); 79 c->set_output(1, batch_cell_shape); 80 c->set_output(2, batch_cell_shape); 81 c->set_output(3, c->Matrix(batch_size, twice_cell_size)); 82 return tensorflow::Status::OK(); 83 }); 84 85 REGISTER_OP("LSTMBlockCell") 86 .Input("x: T") 87 .Input("cs_prev: T") 88 .Input("h_prev: T") 89 .Input("w: T") 90 .Input("wci: T") 91 .Input("wcf: T") 92 .Input("wco: T") 93 .Input("b: T") 94 .Output("i: T") 95 .Output("cs: T") 96 .Output("f: T") 97 .Output("o: T") 98 .Output("ci: T") 99 .Output("co: T") 100 .Output("h: T") 101 .Attr("forget_bias: float = 1.0") 102 .Attr("cell_clip: float = 3.0") 103 .Attr("use_peephole: bool = false") 104 .Attr("T: {half, float}") __anon652ace880302(InferenceContext* c) 105 .SetShapeFn([](InferenceContext* c) { 106 ShapeHandle x, cs_prev; 107 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x)); 108 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &cs_prev)); 109 110 DimensionHandle batch_size = c->Dim(x, 0); 111 DimensionHandle cell_size = c->Dim(cs_prev, 1); 112 ShapeHandle output = c->Matrix(batch_size, cell_size); 113 for (int i = 0; i < 7; ++i) { 114 c->set_output(i, output); 115 } 116 return tensorflow::Status::OK(); 117 }); 118 119 REGISTER_OP("LSTMBlockCellGrad") 120 .Input("x: T") 121 .Input("cs_prev: T") 122 .Input("h_prev: T") 123 .Input("w: T") 124 .Input("wci: T") 125 .Input("wcf: T") 126 .Input("wco: T") 127 .Input("b: T") 128 .Input("i: T") 129 .Input("cs: T") 130 .Input("f: T") 131 .Input("o: T") 132 .Input("ci: T") 133 .Input("co: T") 134 .Input("cs_grad: T") 135 .Input("h_grad: T") 136 .Output("cs_prev_grad: T") 137 .Output("dicfo: T") 138 .Output("wci_grad: T") 139 .Output("wcf_grad: T") 140 .Output("wco_grad: T") 141 .Attr("use_peephole: bool") 142 .Attr("T: {half, float}") __anon652ace880402(InferenceContext* c) 143 .SetShapeFn([](InferenceContext* c) { 144 ShapeHandle x, cs_prev; 145 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x)); 146 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &cs_prev)); 147 148 DimensionHandle batch_size = c->Dim(x, 0); 149 DimensionHandle cell_size = c->Dim(cs_prev, 1); 150 DimensionHandle cell_size_times_4; 151 TF_RETURN_IF_ERROR(c->Multiply(cell_size, 4, &cell_size_times_4)); 152 ShapeHandle cell_size_vec = c->Vector(cell_size); 153 154 c->set_output(0, c->Matrix(batch_size, cell_size)); 155 c->set_output(1, c->Matrix(batch_size, cell_size_times_4)); 156 c->set_output(2, cell_size_vec); 157 c->set_output(3, cell_size_vec); 158 c->set_output(4, cell_size_vec); 159 return tensorflow::Status::OK(); 160 }); 161 162 REGISTER_OP("BlockLSTM") 163 .Input("seq_len_max: int64") 164 .Input("x: T") 165 .Input("cs_prev: T") 166 .Input("h_prev: T") 167 .Input("w: T") 168 .Input("wci: T") 169 .Input("wcf: T") 170 .Input("wco: T") 171 .Input("b: T") 172 .Output("i: T") 173 .Output("cs: T") 174 .Output("f: T") 175 .Output("o: T") 176 .Output("ci: T") 177 .Output("co: T") 178 .Output("h: T") 179 .Attr("forget_bias: float = 1.0") 180 .Attr("cell_clip: float = 3.0") 181 .Attr("use_peephole: bool = false") 182 .Attr("T: {half, float}") __anon652ace880502(InferenceContext* c) 183 .SetShapeFn([](InferenceContext* c) { 184 ShapeHandle x, b; 185 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x)); 186 TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &b)); 187 188 DimensionHandle timelen = c->Dim(x, 0); 189 DimensionHandle batch_size = c->Dim(x, 1); 190 DimensionHandle cell_size; 191 TF_RETURN_IF_ERROR( 192 c->Divide(c->Dim(b, 0), 4, true /* evenly_divisible */, &cell_size)); 193 194 DCHECK_EQ(7, c->num_outputs()); 195 ShapeHandle output = c->MakeShape({timelen, batch_size, cell_size}); 196 for (int i = 0; i < 7; ++i) { 197 c->set_output(i, output); 198 } 199 return Status::OK(); 200 }); 201 202 REGISTER_OP("BlockLSTMV2") 203 .Input("seq_len_max: int64") 204 .Input("x: T") 205 .Input("cs_prev: T") 206 .Input("h_prev: T") 207 .Input("w: T") 208 .Input("wci: T") 209 .Input("wcf: T") 210 .Input("wco: T") 211 .Input("b: T") 212 .Output("i: T") 213 .Output("cs: T") 214 .Output("f: T") 215 .Output("o: T") 216 .Output("ci: T") 217 .Output("co: T") 218 .Output("h: T") 219 .Attr("cell_clip: float = 0.0") 220 .Attr("use_peephole: bool = false") 221 .Attr("T: {half, float}") __anon652ace880602(InferenceContext* c) 222 .SetShapeFn([](InferenceContext* c) { 223 ShapeHandle x, b; 224 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x)); 225 TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &b)); 226 227 DimensionHandle timelen = c->Dim(x, 0); 228 DimensionHandle batch_size = c->Dim(x, 1); 229 DimensionHandle cell_size; 230 TF_RETURN_IF_ERROR( 231 c->Divide(c->Dim(b, 0), 4, true /* evenly_divisible */, &cell_size)); 232 233 DCHECK_EQ(7, c->num_outputs()); 234 ShapeHandle output = c->MakeShape({timelen, batch_size, cell_size}); 235 for (int i = 0; i < 7; ++i) { 236 c->set_output(i, output); 237 } 238 return Status::OK(); 239 }); 240 241 REGISTER_OP("BlockLSTMGrad") 242 .Input("seq_len_max: int64") 243 .Input("x: T") 244 .Input("cs_prev: T") 245 .Input("h_prev: T") 246 .Input("w: T") 247 .Input("wci: T") 248 .Input("wcf: T") 249 .Input("wco: T") 250 .Input("b: T") 251 .Input("i: T") 252 .Input("cs: T") 253 .Input("f: T") 254 .Input("o: T") 255 .Input("ci: T") 256 .Input("co: T") 257 .Input("h: T") 258 .Input("cs_grad: T") 259 .Input("h_grad: T") 260 .Output("x_grad: T") 261 .Output("cs_prev_grad: T") 262 .Output("h_prev_grad: T") 263 .Output("w_grad: T") 264 .Output("wci_grad: T") 265 .Output("wcf_grad: T") 266 .Output("wco_grad: T") 267 .Output("b_grad: T") 268 .Attr("use_peephole: bool") 269 .Attr("T: {half, float}") __anon652ace880702(InferenceContext* c) 270 .SetShapeFn([](InferenceContext* c) { 271 ShapeHandle x, cs_prev, h_prev, w, wci, wco, wcf, b; 272 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x)); 273 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &cs_prev)); 274 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &h_prev)); 275 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &w)); 276 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &wci)); 277 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &wco)); 278 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &wcf)); 279 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 1, &b)); 280 281 c->set_output(0, x); 282 c->set_output(1, cs_prev); 283 c->set_output(2, h_prev); 284 c->set_output(3, w); 285 c->set_output(4, wci); 286 c->set_output(5, wco); 287 c->set_output(6, wcf); 288 c->set_output(7, b); 289 290 return Status::OK(); 291 }); 292 293 REGISTER_OP("BlockLSTMGradV2") 294 .Input("seq_len_max: int64") 295 .Input("x: T") 296 .Input("cs_prev: T") 297 .Input("h_prev: T") 298 .Input("w: T") 299 .Input("wci: T") 300 .Input("wcf: T") 301 .Input("wco: T") 302 .Input("b: T") 303 .Input("i: T") 304 .Input("cs: T") 305 .Input("f: T") 306 .Input("o: T") 307 .Input("ci: T") 308 .Input("co: T") 309 .Input("h: T") 310 .Input("cs_grad: T") 311 .Input("h_grad: T") 312 .Output("x_grad: T") 313 .Output("cs_prev_grad: T") 314 .Output("h_prev_grad: T") 315 .Output("w_grad: T") 316 .Output("wci_grad: T") 317 .Output("wcf_grad: T") 318 .Output("wco_grad: T") 319 .Output("b_grad: T") 320 .Attr("use_peephole: bool") 321 .Attr("T: {half, float}") __anon652ace880802(InferenceContext* c) 322 .SetShapeFn([](InferenceContext* c) { 323 ShapeHandle x, cs_prev, h_prev, w, wci, wco, wcf, b; 324 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x)); 325 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &cs_prev)); 326 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &h_prev)); 327 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &w)); 328 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &wci)); 329 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &wco)); 330 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &wcf)); 331 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 1, &b)); 332 333 c->set_output(0, x); 334 c->set_output(1, cs_prev); 335 c->set_output(2, h_prev); 336 c->set_output(3, w); 337 c->set_output(4, wci); 338 c->set_output(5, wco); 339 c->set_output(6, wcf); 340 c->set_output(7, b); 341 342 return Status::OK(); 343 }); 344 345 } // end namespace tensorflow 346