• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/function.h"
17 #include "tensorflow/core/lib/core/errors.h"
18 #include "tensorflow/core/util/padding.h"
19 #include "tensorflow/core/util/tensor_format.h"
20 
21 namespace tensorflow {
22 
23 typedef FunctionDefHelper FDH;
24 
SoftmaxGrad(const AttrSlice & attrs,FunctionDef * g)25 Status SoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) {
26   // clang-format off
27   *g = FDH::Define(
28       "SoftmaxGrad",
29       // Arg defs
30       {"x: T", "grad_softmax: T"},
31       // Ret val defs
32       {"grad_x: T"},
33       // Attr defs
34       {{"T: {float, double}"}},
35       // Nodes
36       // Based on _SoftmaxGrad in nn_grad.py.
37       {
38         {{"softmax"}, "Softmax", {"x"}, {{"T", "$T"}}},
39         {{"n0"}, "Mul", {"grad_softmax", "softmax"}, {{"T", "$T"}}},
40         FDH::Const<int32>("indices", {-1}),
41         {{"n1"}, "Sum", {"n0", "indices"}, {{"keep_dims", true}, {"T", "$T"}}},
42         {{"n2"}, "Sub", {"grad_softmax", "n1"}, {{"T", "$T"}}},
43         {{"grad_x"}, "Mul", {"n2", "softmax"}, {{"T", "$T"}}}
44       });
45   // clang-format on
46   return Status::OK();
47 }
48 REGISTER_OP_GRADIENT("Softmax", SoftmaxGrad);
49 
LogSoftmaxGrad(const AttrSlice & attrs,FunctionDef * g)50 Status LogSoftmaxGrad(const AttrSlice& attrs, FunctionDef* g) {
51   // clang-format off
52   *g = FDH::Define(
53       "LogSoftmaxGrad",
54       // Arg defs
55       {"x: T", "grad_logsoftmax: T"},
56       // Ret val defs
57       {"grad_x: T"},
58       // Attr defs
59       {{"T: {float, double}"}},
60       // Nodes
61       // Based on _LogSoftmaxGrad in nn_grad.py.
62       {
63         {{"softmax"}, "Softmax", {"x"}, {{"T", "$T"}}},
64         FDH::Const<int32>("indices", {-1}),
65         {{"n0"}, "Sum", {"grad_logsoftmax", "indices"},
66          {{"keep_dims", true}, {"T", "$T"}}},
67         {{"n1"}, "Mul", {"n0", "softmax"}, {{"T", "$T"}}},
68         {{"grad_x"}, "Sub", {"grad_logsoftmax", "n1"}, {{"T", "$T"}}}
69       });
70   // clang-format on
71   return Status::OK();
72 }
73 REGISTER_OP_GRADIENT("LogSoftmax", LogSoftmaxGrad);
74 
ReluGrad(const AttrSlice & attrs,FunctionDef * g)75 Status ReluGrad(const AttrSlice& attrs, FunctionDef* g) {
76   // clang-format off
77   *g = FDH::Define(
78       // Arg defs
79       {"x: T", "dy: T"},
80       // Ret val defs
81       {"dx: T"},
82       // Attr defs
83       {{"T: {float, double}"}},
84       // Nodes
85       {
86         {{"dx"}, "ReluGrad", {"dy", "x"}, {{"T", "$T"}}}
87       });
88   // clang-format on
89   return Status::OK();
90 }
91 REGISTER_OP_GRADIENT("Relu", ReluGrad);
92 
Relu6Grad(const AttrSlice & attrs,FunctionDef * g)93 Status Relu6Grad(const AttrSlice& attrs, FunctionDef* g) {
94   // clang-format off
95   *g = FDH::Define(
96       // Arg defs
97       {"x: T", "dy: T"},
98       // Ret val defs
99       {"dx: T"},
100       // Attr defs
101       {{"T: {float, double}"}},
102       // Nodes
103       {
104         {{"dx"}, "Relu6Grad", {"dy", "x"}, {{"T", "$T"}}}
105       });
106   // clang-format on
107   return Status::OK();
108 }
109 REGISTER_OP_GRADIENT("Relu6", Relu6Grad);
110 
CrossEntropyGrad(const AttrSlice & attrs,FunctionDef * g)111 Status CrossEntropyGrad(const AttrSlice& attrs, FunctionDef* g) {
112   // clang-format off
113   *g = FDH::Define(
114     // Arg defs
115     {"features: T", "labels: T", "dcost_dloss: T", "donotcare: T"},
116     // Ret val defs
117     {"dcost_dfeatures: T", "dcost_dlabels: T"},
118     // Attr defs
119     {{"T: {float, double}"}},
120     // Nodes
121     {
122       // _, dloss_dfeatures = CrossEntropy(features, labels)
123       {{"donotcare_loss", "dloss_dfeatures"}, "CrossEntropy",
124        {"features", "labels"}, {{"T", "$T"}}},
125       // dcost_dloss is of shape [batch_size].
126       // dcost_dloss_mat is of shape [batch_size, 1].
127       FDH::Const("neg1", -1),
128       {{"dcost_dloss_mat"}, "ExpandDims", {"dcost_dloss", "neg1"},
129        {{"T", "$T"}}},
130       // chain rule: dcost/dfeatures = dcost/dloss * dloss/dfeatures
131       {{"dcost_dfeatures"}, "Mul", {"dcost_dloss_mat", "dloss_dfeatures"},
132        {{"T", "$T"}}},
133       {{"dcost_dlabels"}, "ZerosLike", {"labels"}, {{"T", "$T"}}},
134     });
135   // clang-format on
136   return Status::OK();
137 }
138 REGISTER_OP_GRADIENT("CrossEntropy", CrossEntropyGrad);
139 
Conv2DGrad(const AttrSlice & attrs,FunctionDef * g)140 Status Conv2DGrad(const AttrSlice& attrs, FunctionDef* g) {
141   // clang-format off
142   *g = FDH::Define(
143     // Arg defs
144     {"input: T", "filter: T", "grad: T"},
145     // Ret val defs
146     {"input_grad: T", "filter_grad: T"},
147     // Attr defs
148     {"T: {float, double}",
149      "strides: list(int)",
150      "use_cudnn_on_gpu: bool = true",
151      GetPaddingAttrString(),
152      GetConvnetDataFormatAttrString()},
153     // Nodes
154     {
155       {{"i_shape"}, "Shape", {"input"}, {{"T", "$T"}}},
156       {{"input_grad"}, "Conv2DBackpropInput", {"i_shape", "filter", "grad"},
157        /*Attrs=*/{{"T", "$T"},
158                   {"strides", "$strides"},
159                   {"padding", "$padding"},
160                   {"data_format", "$data_format"},
161                   {"use_cudnn_on_gpu", "$use_cudnn_on_gpu"}}},
162 
163       {{"f_shape"}, "Shape", {"filter"}, {{"T", "$T"}}},
164       {{"filter_grad"}, "Conv2DBackpropFilter", {"input", "f_shape", "grad"},
165        /*Attrs=*/{{"T", "$T"},
166                   {"strides", "$strides"},
167                   {"padding", "$padding"},
168                   {"data_format", "$data_format"},
169                   {"use_cudnn_on_gpu", "$use_cudnn_on_gpu"}}},
170     });
171   // clang-format on
172   return Status::OK();
173 }
174 REGISTER_OP_GRADIENT("Conv2D", Conv2DGrad);
175 
MaxPoolGrad(const AttrSlice & attrs,FunctionDef * g)176 Status MaxPoolGrad(const AttrSlice& attrs, FunctionDef* g) {
177   // clang-format off
178   *g = FDH::Define(
179     // Arg defs
180     {"input: T", "grad: T"},
181     // Ret val defs
182     {"output: T"},
183     // Attr defs
184     {"T: {float, half} = DT_FLOAT",
185      "ksize: list(int) >= 4",
186      "strides: list(int) >= 4",
187      GetPaddingAttrString()},
188     // Nodes
189     {
190       // Invoke MaxPool again to recompute the outputs (removed by CSE?).
191       {{"maxpool"}, "MaxPool", {"input"},
192        /*Attrs=*/{{"T", "$T"},
193                   {"ksize", "$ksize"},
194                   {"strides", "$strides"},
195                   {"padding", "$padding"}}},
196       {{"output"}, "MaxPoolGrad", {"input", "maxpool", "grad"},
197        /*Attrs=*/{{"T", "$T"},
198                   {"ksize", "$ksize"},
199                   {"strides", "$strides"},
200                   {"padding", "$padding"}}}
201     });
202   // clang-format on
203   return Status::OK();
204 }
205 REGISTER_OP_GRADIENT("MaxPool", MaxPoolGrad);
206 
AvgPoolGrad(const AttrSlice & attrs,FunctionDef * g)207 Status AvgPoolGrad(const AttrSlice& attrs, FunctionDef* g) {
208   // clang-format off
209   *g = FDH::Define(
210     // Arg defs
211     {"input: T", "grad: T"},
212     // Ret val defs
213     {"output: T"},
214     // Attr defs
215     {"T: {float, half} = DT_FLOAT",
216      "ksize: list(int) >= 4",
217      "strides: list(int) >= 4",
218      GetPaddingAttrString()},
219     // Nodes
220     {
221       {{"i_shape"}, "Shape", {"input"}, {{"T", "$T"}}},
222       {{"output"}, "AvgPoolGrad", {"i_shape", "grad"},
223        /*Attrs=*/{{"T", "$T"},
224                   {"ksize", "$ksize"},
225                   {"strides", "$strides"},
226                   {"padding", "$padding"}}}
227     });
228   // clang-format on
229   return Status::OK();
230 }
231 REGISTER_OP_GRADIENT("AvgPool", AvgPoolGrad);
232 
MaxPoolGradGrad(const AttrSlice & attrs,FunctionDef * g)233 Status MaxPoolGradGrad(const AttrSlice& attrs, FunctionDef* g) {
234   // clang-format off
235   *g = FDH::Define(
236     // Arg defs
237     {"input: T", "grad: T"},
238     // Ret val defs
239     {"output: T"},
240     // Attr defs
241     {"T: {float, half} = DT_FLOAT",
242      "ksize: list(int) >= 4",
243      "strides: list(int) >= 4",
244      GetPaddingAttrString()},
245     // Nodes
246     {
247       // Invoke MaxPool again to recompute the outputs (removed by CSE?).
248       {{"maxpool"}, "MaxPool", {"input"},
249        /*Attrs=*/{{"T", "$T"},
250                   {"ksize", "$ksize"},
251                   {"strides", "$strides"},
252                   {"padding", "$padding"}}},
253       {{"output"}, "MaxPoolGradGrad", {"input", "maxpool", "grad"},
254        /*Attrs=*/{{"T", "$T"},
255                   {"ksize", "$ksize"},
256                   {"strides", "$strides"},
257                   {"padding", "$padding"}}}
258     });
259   // clang-format on
260   return Status::OK();
261 }
262 REGISTER_OP_GRADIENT("MaxPoolGrad", MaxPoolGradGrad);
263 
BiasAddGrad(const AttrSlice & attrs,FunctionDef * g)264 Status BiasAddGrad(const AttrSlice& attrs, FunctionDef* g) {
265   // clang-format off
266   *g = FDH::Define(
267     // Arg defs
268     {"input: T", "bias: T", "grad: T"},
269     // Ret val defs
270     {"grad: T", "bias_grad: T"},
271     // Attr defs
272     {{"T: {float, double}"},
273      GetConvnetDataFormatAttrString()},
274     // Nodes
275     {
276       {{"bias_grad"}, "BiasAddGrad", {"grad"},
277            /*Attrs=*/{{"T", "$T"},
278                       {"data_format", "$data_format"}}}
279     });
280   // clang-format on
281   return Status::OK();
282 }
283 REGISTER_OP_GRADIENT("BiasAdd", BiasAddGrad);
284 
285 }  // end namespace tensorflow
286