/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/core/framework/op.h"
17 #include "tensorflow/core/framework/shape_inference.h"
18 
19 using tensorflow::shape_inference::DimensionHandle;
20 using tensorflow::shape_inference::InferenceContext;
21 using tensorflow::shape_inference::ShapeHandle;
22 
23 REGISTER_OP("GRUBlockCell")
24     .Attr("T: {float}")
25     .Input("x: T")
26     .Input("h_prev: T")
27     .Input("w_ru: T")
28     .Input("w_c: T")
29     .Input("b_ru: T")
30     .Input("b_c: T")
31     .Output("r: T")
32     .Output("u: T")
33     .Output("c: T")
34     .Output("h: T")
__anon07a5ee8f0102(InferenceContext* c) 35     .SetShapeFn([](InferenceContext* c) {
36       ShapeHandle x, h_prev;
37       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x));
38       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &h_prev));
39 
40       DimensionHandle batch_size = c->Dim(x, 0);
41       DimensionHandle cell_size = c->Dim(h_prev, 1);
42       ShapeHandle output = c->Matrix(batch_size, cell_size);
43       for (int i = 0; i < 4; ++i) {
44         c->set_output(i, output);
45       }
46       return tensorflow::Status::OK();
47     })
48     .Doc(R"doc(
49 Computes the GRU cell forward propagation for 1 time step.
50 
51 Args
52     x: Input to the GRU cell.
53     h_prev: State input from the previous GRU cell.
54     w_ru: Weight matrix for the reset and update gate.
55     w_c: Weight matrix for the cell connection gate.
56     b_ru: Bias vector for the reset and update gate.
57     b_c: Bias vector for the cell connection gate.
58 
59 Returns
60     r: Output of the reset gate.
61     u: Output of the update gate.
62     c: Output of the cell connection gate.
63     h: Current state of the GRU cell.
64 
65 Note on notation of the variables:
66 
67 Concatenation of a and b is represented by a_b
68 Element-wise dot product of a and b is represented by ab
69 Element-wise dot product is represented by \circ
70 Matrix multiplication is represented by *
71 
72 Biases are initialized with :
73 `b_ru` - constant_initializer(1.0)
74 `b_c` - constant_initializer(0.0)
75 
76 This kernel op implements the following mathematical equations:
77 
78 ```
79 x_h_prev = [x, h_prev]
80 
81 [r_bar u_bar] = x_h_prev * w_ru + b_ru
82 
83 r = sigmoid(r_bar)
84 u = sigmoid(u_bar)
85 
86 h_prevr = h_prev \circ r
87 
88 x_h_prevr = [x h_prevr]
89 
90 c_bar = x_h_prevr * w_c + b_c
91 c = tanh(c_bar)
92 
93 h = (1-u) \circ c + u \circ h_prev
94 ```
95 )doc");
96 
97 REGISTER_OP("GRUBlockCellGrad")
98     .Attr("T: {float}")
99     .Input("x: T")
100     .Input("h_prev: T")
101     .Input("w_ru: T")
102     .Input("w_c: T")
103     .Input("b_ru: T")
104     .Input("b_c: T")
105     .Input("r: T")
106     .Input("u: T")
107     .Input("c: T")
108     .Input("d_h: T")
109     .Output("d_x: T")
110     .Output("d_h_prev: T")
111     .Output("d_c_bar: T")
112     .Output("d_r_bar_u_bar: T")
__anon07a5ee8f0202(InferenceContext* c) 113     .SetShapeFn([](InferenceContext* c) {
114       ShapeHandle x, h_prev, w_ru;
115       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &x));
116       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &h_prev));
117       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &w_ru));
118 
119       DimensionHandle batch_size = c->Dim(x, 0);
120       DimensionHandle cell_size = c->Dim(h_prev, 1);
121       DimensionHandle twice_cell_size = c->Dim(w_ru, 1);
122       ShapeHandle batch_cell_shape = c->Matrix(batch_size, cell_size);
123 
124       c->set_output(0, x);
125       c->set_output(1, batch_cell_shape);
126       c->set_output(2, batch_cell_shape);
127       c->set_output(3, c->Matrix(batch_size, twice_cell_size));
128       return tensorflow::Status::OK();
129     })
130     .Doc(R"doc(
131 Computes the GRU cell back-propagation for 1 time step.
132 
133 Args
134     x: Input to the GRU cell.
135     h_prev: State input from the previous GRU cell.
136     w_ru: Weight matrix for the reset and update gate.
137     w_c: Weight matrix for the cell connection gate.
138     b_ru: Bias vector for the reset and update gate.
139     b_c: Bias vector for the cell connection gate.
140     r: Output of the reset gate.
141     u: Output of the update gate.
142     c: Output of the cell connection gate.
143     d_h: Gradients of the h_new wrt to objective function.
144 
145 Returns
146     d_x: Gradients of the x wrt to objective function.
147     d_h_prev: Gradients of the h wrt to objective function.
148     d_c_bar Gradients of the c_bar wrt to objective function.
149     d_r_bar_u_bar Gradients of the r_bar & u_bar wrt to objective function.
150 
151 This kernel op implements the following mathematical equations:
152 
153 Note on notation of the variables:
154 
155 Concatenation of a and b is represented by a_b
156 Element-wise dot product of a and b is represented by ab
157 Element-wise dot product is represented by \circ
158 Matrix multiplication is represented by *
159 
160 Additional notes for clarity:
161 
162 `w_ru` can be segmented into 4 different matrices.
163 ```
164 w_ru = [w_r_x w_u_x
165         w_r_h_prev w_u_h_prev]
166 ```
167 Similarly, `w_c` can be segmented into 2 different matrices.
168 ```
169 w_c = [w_c_x w_c_h_prevr]
170 ```
171 Same goes for biases.
172 ```
173 b_ru = [b_ru_x b_ru_h]
174 b_c = [b_c_x b_c_h]
175 ```
176 Another note on notation:
177 ```
178 d_x = d_x_component_1 + d_x_component_2
179 
180 where d_x_component_1 = d_r_bar * w_r_x^T + d_u_bar * w_r_x^T
181 and d_x_component_2 = d_c_bar * w_c_x^T
182 
183 d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + d_h \circ u
184 where d_h_prev_componenet_1 = d_r_bar * w_r_h_prev^T + d_u_bar * w_r_h_prev^T
185 ```
186 
187 Mathematics behind the Gradients below:
188 ```
189 d_c_bar = d_h \circ (1-u) \circ (1-c \circ c)
190 d_u_bar = d_h \circ (h-c) \circ u \circ (1-u)
191 
192 d_r_bar_u_bar = [d_r_bar d_u_bar]
193 
194 [d_x_component_1 d_h_prev_component_1] = d_r_bar_u_bar * w_ru^T
195 
196 [d_x_component_2 d_h_prevr] = d_c_bar * w_c^T
197 
198 d_x = d_x_component_1 + d_x_component_2
199 
200 d_h_prev = d_h_prev_component_1 + d_h_prevr \circ r + u
201 ```
202 Below calculation is performed in the python wrapper for the Gradients
203 (not in the gradient kernel.)
204 ```
205 d_w_ru = x_h_prevr^T * d_c_bar
206 
207 d_w_c = x_h_prev^T * d_r_bar_u_bar
208 
209 d_b_ru = sum of d_r_bar_u_bar along axis = 0
210 
211 d_b_c = sum of d_c_bar along axis = 0
212 ```
213 )doc");
214