1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/op.h" 17 #include "tensorflow/core/framework/shape_inference.h" 18 19 using tensorflow::shape_inference::DimensionHandle; 20 using tensorflow::shape_inference::InferenceContext; 21 using tensorflow::shape_inference::ShapeHandle; 22 23 REGISTER_OP("GRUBlockCell") 24 .Attr("T: {float}") 25 .Input("x: T") 26 .Input("h_prev: T") 27 .Input("w_ru: T") 28 .Input("w_c: T") 29 .Input("b_ru: T") 30 .Input("b_c: T") 31 .Output("r: T") 32 .Output("u: T") 33 .Output("c: T") 34 .Output("h: T") __anon07a5ee8f0102(InferenceContext* c) 35 .SetShapeFn([](InferenceContext* c) { 36 ShapeHandle x, h_prev; 37 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x)); 38 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &h_prev)); 39 40 DimensionHandle batch_size = c->Dim(x, 0); 41 DimensionHandle cell_size = c->Dim(h_prev, 1); 42 ShapeHandle output = c->Matrix(batch_size, cell_size); 43 for (int i = 0; i < 4; ++i) { 44 c->set_output(i, output); 45 } 46 return tensorflow::Status::OK(); 47 }) 48 .Doc(R"doc( 49 Computes the GRU cell forward propagation for 1 time step. 50 51 Args 52 x: Input to the GRU cell. 53 h_prev: State input from the previous GRU cell. 54 w_ru: Weight matrix for the reset and update gate. 55 w_c: Weight matrix for the cell connection gate. 56 b_ru: Bias vector for the reset and update gate. 57 b_c: Bias vector for the cell connection gate. 58 59 Returns 60 r: Output of the reset gate. 61 u: Output of the update gate. 62 c: Output of the cell connection gate. 63 h: Current state of the GRU cell. 64 65 Note on notation of the variables: 66 67 Concatenation of a and b is represented by a_b 68 Element-wise dot product of a and b is represented by ab 69 Element-wise dot product is represented by \circ 70 Matrix multiplication is represented by * 71 72 Biases are initialized with : 73 `b_ru` - constant_initializer(1.0) 74 `b_c` - constant_initializer(0.0) 75 76 This kernel op implements the following mathematical equations: 77 78 ``` 79 x_h_prev = [x, h_prev] 80 81 [r_bar u_bar] = x_h_prev * w_ru + b_ru 82 83 r = sigmoid(r_bar) 84 u = sigmoid(u_bar) 85 86 h_prevr = h_prev \circ r 87 88 x_h_prevr = [x h_prevr] 89 90 c_bar = x_h_prevr * w_c + b_c 91 c = tanh(c_bar) 92 93 h = (1-u) \circ c + u \circ h_prev 94 ``` 95 )doc"); 96 97 REGISTER_OP("GRUBlockCellGrad") 98 .Attr("T: {float}") 99 .Input("x: T") 100 .Input("h_prev: T") 101 .Input("w_ru: T") 102 .Input("w_c: T") 103 .Input("b_ru: T") 104 .Input("b_c: T") 105 .Input("r: T") 106 .Input("u: T") 107 .Input("c: T") 108 .Input("d_h: T") 109 .Output("d_x: T") 110 .Output("d_h_prev: T") 111 .Output("d_c_bar: T") 112 .Output("d_r_bar_u_bar: T") __anon07a5ee8f0202(InferenceContext* c) 113 .SetShapeFn([](InferenceContext* c) { 114 ShapeHandle x, h_prev, w_ru; 115 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x)); 116 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &h_prev)); 117 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &w_ru)); 118 119 DimensionHandle batch_size = c->Dim(x, 0); 120 DimensionHandle cell_size = c->Dim(h_prev, 1); 121 DimensionHandle twice_cell_size = c->Dim(w_ru, 1); 122 ShapeHandle batch_cell_shape = c->Matrix(batch_size, cell_size); 123 124 c->set_output(0, x); 125 c->set_output(1, batch_cell_shape); 126 c->set_output(2, batch_cell_shape); 127 c->set_output(3, c->Matrix(batch_size, twice_cell_size)); 128 return tensorflow::Status::OK(); 129 }) 130 .Doc(R"doc( 131 Computes the GRU cell back-propagation for 1 time step. 132 133 Args 134 x: Input to the GRU cell. 135 h_prev: State input from the previous GRU cell. 136 w_ru: Weight matrix for the reset and update gate. 137 w_c: Weight matrix for the cell connection gate. 138 b_ru: Bias vector for the reset and update gate. 139 b_c: Bias vector for the cell connection gate. 140 r: Output of the reset gate. 141 u: Output of the update gate. 142 c: Output of the cell connection gate. 143 d_h: Gradients of the h_new wrt to objective function. 144 145 Returns 146 d_x: Gradients of the x wrt to objective function. 147 d_h_prev: Gradients of the h wrt to objective function. 148 d_c_bar Gradients of the c_bar wrt to objective function. 149 d_r_bar_u_bar Gradients of the r_bar & u_bar wrt to objective function. 150 151 This kernel op implements the following mathematical equations: 152 153 Note on notation of the variables: 154 155 Concatenation of a and b is represented by a_b 156 Element-wise dot product of a and b is represented by ab 157 Element-wise dot product is represented by \circ 158 Matrix multiplication is represented by * 159 160 Additional notes for clarity: 161 162 `w_ru` can be segmented into 4 different matrices. 163 ``` 164 w_ru = [w_r_x w_u_x 165 w_r_h_prev w_u_h_prev] 166 ``` 167 Similarly, `w_c` can be segmented into 2 different matrices. 168 ``` 169 w_c = [w_c_x w_c_h_prevr] 170 ``` 171 Same goes for biases. 172 ``` 173 b_ru = [b_ru_x b_ru_h] 174 b_c = [b_c_x b_c_h] 175 ``` 176 Another note on notation: 177 ``` 178 d_x = d_x_component_1 + d_x_component_2 179 180 where d_x_component_1 = d_r_bar * w_r_x^T + d_u_bar * w_r_x^T 181 and d_x_component_2 = d_c_bar * w_c_x^T 182 183 d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + d_h \circ u 184 where d_h_prev_componenet_1 = d_r_bar * w_r_h_prev^T + d_u_bar * w_r_h_prev^T 185 ``` 186 187 Mathematics behind the Gradients below: 188 ``` 189 d_c_bar = d_h \circ (1-u) \circ (1-c \circ c) 190 d_u_bar = d_h \circ (h-c) \circ u \circ (1-u) 191 192 d_r_bar_u_bar = [d_r_bar d_u_bar] 193 194 [d_x_component_1 d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T 195 196 [d_x_component_2 d_h_prevr] = d_c_bar * w_c^T 197 198 d_x = d_x_component_1 + d_x_component_2 199 200 d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + u 201 ``` 202 Below calculation is performed in the python wrapper for the Gradients 203 (not in the gradient kernel.) 204 ``` 205 d_w_ru = x_h_prevr^T * d_c_bar 206 207 d_w_c = x_h_prev^T * d_r_bar_u_bar 208 209 d_b_ru = sum of d_r_bar_u_bar along axis = 0 210 211 d_b_c = sum of d_c_bar along axis = 0 212 ``` 213 )doc"); 214