/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

typedef FunctionDefHelper FDH;

// Cwise unary ops
Status GradForUnaryCwise(FunctionDef* g, std::vector<FDH::Node> nodes) {
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", "$T"}};
    }
  }
  *g = FDH::Define(
      // Arg defs
      {"x: T", "dy: T"},
      // Ret val defs
      {"dx: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      nodes);
  return Status::OK();
}
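
// A sketch of what GradForUnaryCwise produces (illustrative, not actual
// generated output), using AbsGrad below:
//
//   AbsGrad(x: T, dy: T) -> (dx: T) {
//     sign = Sign(x)   // the trailing {"dy"} adds a ^dy control input
//     dx = Mul(dy, sign)
//   }
//
// Nodes with an empty attr list are given {"T", "$T"}, so each op
// instantiates at the function's type parameter T.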

Status AbsGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"sign"}, "Sign", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "sign"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Abs", AbsGrad);

Status NegGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"dx"}, "Neg", {"dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Neg", NegGrad);

Status InvGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Reciprocal", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      {{"y2_neg"}, "Neg", {"y2"}},
      {{"dx"}, "Mul", {"dy", "y2_neg"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Inv", InvGrad);
REGISTER_OP_GRADIENT("Reciprocal", InvGrad);

Status SquareGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("c", 2LL),
      {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
      {{"x2"}, "Mul", {"x", "two"}, {}, {"dy"}},  // x * 2
      {{"dx"}, "Mul", {"dy", "x2"}},              // dy * (x * 2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Square", SquareGrad);

Status SqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Sqrt", {"x"}},
      {{"y_inv"}, "Reciprocal", {"y"}, {}, {"dy"}},
      FDH::Const("const", 0.5f),
      {{"half"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Mul", {"half", "y_inv"}},  // .5 * 1/y
      {{"dx"}, "Mul", {"dy", "a"}},  // dy * (.5 * 1/y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sqrt", SqrtGrad);
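
// Derivation for SqrtGrad above: y = sqrt(x), so dy/dx = 0.5 / sqrt(x)
// = 0.5 * (1/y), matching the product formed there.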

Status RsqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}},
      {{"y"}, "Rsqrt", {"x"}},
      FDH::Const("const", -.5f),
      {{"neghalf"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Mul", {"neghalf", "x_inv"}},   // -0.5 * 1/x
      {{"b"}, "Mul", {"a", "y"}},             // -0.5 * 1/x * y
      {{"dx"}, "Mul", {"dy", "b"}},           // dy * (-0.5 * 1/x * y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Rsqrt", RsqrtGrad);
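
// Derivation for RsqrtGrad above: y = x^(-1/2), so
// dy/dx = -0.5 * x^(-3/2) = -0.5 * (1/x) * y.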

Status ExpGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Exp", {"x"}},
      {{"dx"}, "Mul", {"dy", "y"}},           // dy * y
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Exp", ExpGrad);

Status Expm1Grad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Exp", {"x"}},
      {{"dx"}, "Mul", {"dy", "y"}},           // dy * y
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Expm1", Expm1Grad);

Status LogGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "x_inv"}},           // dy * 1/x
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Log", LogGrad);

Status Log1pGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Add", {"one", "x"}},
      {{"dx"}, "Div", {"dy", "a"}},           // dy / (1 + x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Log1p", Log1pGrad);

Status SinhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"cosh"}, "Cosh", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "cosh"}},  // dy * cosh(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sinh", SinhGrad);

Status CoshGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"sinh"}, "Sinh", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "sinh"}},  // dy * sinh(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Cosh", CoshGrad);

Status TanhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},           // dy * (1 - y*y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Tanh", TanhGrad);
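
// Derivation for TanhGrad above: d/dx tanh(x) = 1 - tanh(x)^2 = 1 - y^2.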

Status AsinhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Asinh", {"x"}},
      {{"cosh"}, "Cosh", {"y"}},
      {{"dx"}, "Mul", {"dy", "cosh"}},  // dy * cosh(y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Asinh", AsinhGrad);

Status AcoshGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Acosh", {"x"}},
      {{"sinh"}, "Sinh", {"y"}},
      {{"dx"}, "Mul", {"dy", "sinh"}},  // dy * sinh(y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Acosh", AcoshGrad);

Status AtanhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
    {{"x2"}, "Square", {"x"}},
    FDH::Const("const", 1.0f),
    {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
    {{"a"}, "Sub", {"one", "x2"}}, // 1 - x^2
    {{"inv"}, "Reciprocal", {"a"}},
    {{"dx"}, "Mul", {"dy", "inv"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Atanh", AtanhGrad);

Status SigmoidGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Sigmoid", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "y"}, {}, {"dy"}},
      {{"b"}, "Mul", {"y", "a"}},             // y * (1 - y)
      {{"dx"}, "Mul", {"dy", "b"}},           // dy * y * (1 - y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sigmoid", SigmoidGrad);
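
// Derivation for SigmoidGrad above: y = 1 / (1 + exp(-x)), so
// dy/dx = y * (1 - y).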

Status SignGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"s"}, "Shape", {"x"}},
      FDH::Const("zero", 0.f),
      {{"val"}, "Cast", {"zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"dx"}, "Fill", {"s", "val"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sign", SignGrad);

Status SinGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"cos"}, "Cos", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "cos"}},  // dy * cos(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sin", SinGrad);

Status CosGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"sin"}, "Sin", {"x"}, {}, {"dy"}},
      {{"neg"}, "Neg", {"sin"}},
      {{"dx"}, "Mul", {"dy", "neg"}},  // dy * (-sin(x))
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Cos", CosGrad);

Status AcosGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
    {{"x2"}, "Square", {"x"}},
    FDH::Const("const", 1.0f),
    {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
    {{"a"}, "Sub", {"one", "x2"}}, // 1 - x^2
    {{"b"}, "Sqrt", {"a"}},
    {{"inv"}, "Reciprocal", {"b"}},
    {{"neg"}, "Neg", {"inv"}},
    {{"dx"}, "Mul", {"dy", "neg"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Acos", AcosGrad);

Status AsinGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
    {{"x2"}, "Square", {"x"}},
    FDH::Const("const", 1.0f),
    {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
    {{"a"}, "Sub", {"one", "x2"}}, // 1 - x^2
    {{"b"}, "Sqrt", {"a"}},
    {{"inv"}, "Reciprocal", {"b"}},
    {{"dx"}, "Mul", {"dy", "inv"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Asin", AsinGrad);

Status AtanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
    {{"x2"}, "Square", {"x"}},
    FDH::Const("const", 1.0f),
    {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
    {{"a"}, "Add", {"one", "x2"}}, // 1 + x^2
    {{"inv"}, "Reciprocal", {"a"}},
    {{"dx"}, "Mul", {"dy", "inv"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Atan", AtanGrad);

Status TanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
    {{"cosx"}, "Cos", {"x"}},
    {{"secx"}, "Reciprocal", {"cosx"}},
    {{"secx2"}, "Square", {"secx"}},
    {{"dx"}, "Mul", {"dy", "secx2"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Tan", TanGrad);

Status RealGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("zero", 0.f),
      {{"dx"}, "Complex", {"dy", "zero"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Real", RealGrad);

Status ImagGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("zero", 0.f),
      {{"dx"}, "Complex", {"zero", "dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Imag", ImagGrad);

Status AngleGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"re"}, "Real", {"x"}},
      {{"im"}, "Imag", {"x"}},
      {{"z"}, "Complex", {"im", "re"}},
      {{"z_inv"}, "Reciprocal", {"z"}},
      {{"neg"}, "Neg", {"z_inv"}},
      {{"dx"}, "Mul", {"neg", "dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Angle", AngleGrad);

Status ConjGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"dx"}, "Conj", {"dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Conj", ConjGrad);

Status CastGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x: SrcT", "dy: DstT"},
      // Ret val defs
      {"dx: SrcT"},
      // Attr defs
      {{"SrcT: type"}, {"DstT: type"}},
      // Nodes
      {{{"dx"}, "Cast", {"dy"}, {{"SrcT", "$DstT"}, {"DstT", "$SrcT"}}}});
  return Status::OK();
  // clang-format on
}
REGISTER_OP_GRADIENT("Cast", CastGrad);

// Cwise binary ops
//
// TODO(zhifengc): This can be arranged as a function in the standard
// library.
Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) {
  // clang-format off
  std::vector<FDH::Node> nodes = {
    {{"sx"}, "Shape", {"x"}},
    {{"sy"}, "Shape", {"y"}},
  };
  nodes.insert(nodes.end(), body.begin(), body.end());
  std::vector<FDH::Node> reshapes = {
    {{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}},
    {{"sum_gx"}, "Sum", {"gx", "rx"}},
    {{"dx"}, "Reshape", {"sum_gx", "sx"}},
    {{"sum_gy"}, "Sum", {"gy", "ry"}},
    {{"dy"}, "Reshape", {"sum_gy", "sy"}},
  };
  nodes.insert(nodes.end(), reshapes.begin(), reshapes.end());

  // clang-format on
  for (auto& n : nodes) {
    // "BroadcastGradientArgs" doesn't need any attrs.
    if (n.attr.empty() && n.op != "BroadcastGradientArgs") {
      n.attr = {{"T", "$T"}};
    }
  }
  *g = FDH::Define(
      // Arg defs
      {"x: T", "y: T", "dz: T"},
      // Ret val defs
      {"dx: T", "dy: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      nodes);
  return Status::OK();
}
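
// Why the Sum/Reshape epilogue above: with broadcasting, x and y may have
// different shapes, e.g. x: [2, 3] and y: [3]. BroadcastGradientArgs
// returns the axes along which each input was broadcast (rx = [], ry = [0]
// in that example); summing the elementwise gradients gx/gy over those
// axes and reshaping back to sx/sy yields dx and dy with the shapes of x
// and y.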

Status AddGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Identity", {"dz"}},
      {{"gy"}, "Identity", {"dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Add", AddGrad);

Status SubGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Identity", {"dz"}},
      {{"gy"}, "Neg", {"dz"}},          // -dz
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sub", SubGrad);

Status MulGrad(const AttrSlice& attrs, FunctionDef* g) {
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    return GradForBinaryCwise(
        g, {
               {{"cy"}, "Conj", {"y"}, {}, {"dz"}},
               {{"gx"}, "Mul", {"dz", "cy"}},  // dz * Conj(y)
               {{"cx"}, "Conj", {"x"}, {}, {"dz"}},
               {{"gy"}, "Mul", {"cx", "dz"}},  // Conj(x) * dz
           });
  } else {
    // clang-format off
    return GradForBinaryCwise(g, {
        {{"gx"}, "Mul", {"dz", "y"}},  // dz * y
        {{"gy"}, "Mul", {"x", "dz"}},  // x * dz
    });
    // clang-format on
  }
}
REGISTER_OP_GRADIENT("Mul", MulGrad);

Status MulNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "MulNoNan", {"y", "dz"}},  // y * dz
      {{"gy"}, "MulNoNan", {"x", "dz"}},  // x * dz
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("MulNoNan", MulNoNanGrad);

Status DivGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Div", {"dz", "y"}},
      {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
      {{"y2"}, "Square", {"y"}, {}, {"dz"}},
      {{"nx_y2"}, "Div", {"nx", "y2"}},
      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Div", DivGrad);

Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "RealDiv", {"dz", "y"}},
      {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
      {{"y2"}, "Square", {"y"}, {}, {"dz"}},
      {{"nx_y2"}, "RealDiv", {"nx", "y2"}},
      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("RealDiv", RealDivGrad);

Status DivNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "DivNoNan", {"dz", "y"}},
      {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
      {{"y2"}, "Square", {"y"}, {}, {"dz"}},
      {{"nx_y2"}, "DivNoNan", {"nx", "y2"}},
      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("DivNoNan", DivNoNanGrad);

Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  std::vector<FDH::Node> nodes = {
    {{"z"}, "Pow", {"x", "y"}},
    // dz * y * Pow(x, y - 1)
    FDH::Const("const_zero", 0.0f),
    FDH::Const("const_one", 1.0f),
    {{"zero"}, "Cast", {"const_zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
    {{"one"}, "Cast", {"const_one"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
    {{"t0"}, "Sub", {"y", "one"}, {}, {"dz"}},
    {{"t1"}, "Pow", {"x", "t0"}},
    {{"t2"}, "Mul", {"dz", "y"}},
    {{"gx"}, "Mul", {"t1", "t2"}},
    {{"unsafe_log"}, "Log", {"x"}, {}, {"dz"}},
    {{"zeros"}, "ZerosLike", {"x"}}};
  // clang-format on
  std::vector<FDH::Node> log_x_handling;
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    // dz * z * (x != 0 ? Log(x) : 0)
    // clang-format off
    log_x_handling = {
      {{"nz_x"}, "NotEqual", {"x", "zero"}},
      {{"safe_log"}, "Select", {"nz_x", "unsafe_log", "zeros"}}};
    // clang-format on
  } else {
    // dz * z * (x > 0 ? Log(x) : 0)
    // clang-format off
    log_x_handling = {
      {{"pos_x"}, "Greater", {"x", "zero"}},
      {{"safe_log"}, "Select", {"pos_x", "unsafe_log", "zeros"}}};
    // clang-format on
  }
  nodes.insert(nodes.end(), log_x_handling.begin(), log_x_handling.end());
  nodes.push_back({{"t4"}, "Mul", {"dz", "z"}});
  nodes.push_back({{"gy"}, "Mul", {"safe_log", "t4"}});
  return GradForBinaryCwise(g, nodes);
}
REGISTER_OP_GRADIENT("Pow", PowGrad);
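
// Derivation for PowGrad above: z = x^y gives dz/dx = y * x^(y - 1) (gx)
// and dz/dy = x^y * log(x) = z * log(x) (gy). Log(x) is replaced by 0
// where it would be ill-defined (x <= 0 for reals, x == 0 for complex) so
// the gradient stays finite.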

Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"zeros"}, "ZerosLike", {"x"}},
      {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
      {{"is_zero_cast"}, "Cast", {"is_x_zero"},
        {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"safe_logy"}, "Xlogy", {"is_zero_cast", "y"}},
      {{"xlogygrad"}, "Xdivy", {"x", "y"}},
      {{"gx"}, "Mul", {"safe_logy", "dz"}},
      {{"gy"}, "Mul", {"xlogygrad", "dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Xlogy", XlogyGrad);
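
// Note on XlogyGrad above: Xlogy(x, y) = x * log(y) with the convention
// that the result is 0 when x == 0. The 0/1 mask is_zero_cast zeroes
// d/dx = log(y) where x == 0, and d/dy = x / y is computed with Xdivy so
// it is likewise 0 there.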

Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"zeros"}, "ZerosLike", {"x"}},
      {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
      {{"is_zero_cast"}, "Cast", {"is_x_zero"},
        {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"safe_divy"}, "Xdivy", {"is_zero_cast", "y"}},
      {{"y2"}, "Square", {"y"}},
      {{"negy2"}, "Neg", {"y2"}},
      {{"xdivygrad"}, "Xdivy", {"x", "negy2"}},
      {{"gx"}, "Mul", {"safe_divy", "dz"}},
      {{"gy"}, "Mul", {"xdivygrad", "dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Xdivy", XdivyGrad);

Status SquaredDifferenceGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      FDH::Const("c", 2LL),
      {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
      {{"x_sub_y"}, "Sub", {"x", "y"}},
      {{"two_x_sub_y"}, "Mul", {"two", "x_sub_y"}},  // 2 * (x - y)
      {{"gx"}, "Mul", {"two_x_sub_y", "dz"}},
      {{"gy"}, "Neg", {"gx"}}
    });
  // clang-format on
}
REGISTER_OP_GRADIENT("SquaredDifference", SquaredDifferenceGrad);

Status MaximumMinimumGradHelper(const string& comparator,
                                const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"c"}, comparator, {"x", "y"}, {}, {"dz"}},
      {{"mask"}, "Cast", {"c"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"gx"}, "Mul", {"dz", "mask"}},
      {{"gy"}, "Sub", {"dz", "gx"}},
  });
  // clang-format on
}
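
// In MaximumMinimumGradHelper above, the comparator marks which input
// "won" elementwise: mask is 1 where x was selected and 0 where y was.
// dz is routed to the winner (gx), and the remainder (dz - gx) goes to the
// other input, so on ties the whole gradient flows to x.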

Status MaximumGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MaximumMinimumGradHelper("GreaterEqual", attrs, g);
}
REGISTER_OP_GRADIENT("Maximum", MaximumGrad);

Status MinimumGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MaximumMinimumGradHelper("LessEqual", attrs, g);
}
REGISTER_OP_GRADIENT("Minimum", MinimumGrad);

Status ComplexGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Real", {"dz"}},
      {{"gy"}, "Imag", {"dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Complex", ComplexGrad);

// Cwise ternary ops.
Status SelectGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      {"c:bool", "x:T", "y:T", "dz:T"},
      {"dc:bool", "dx:T", "dy:T"},
      {{"T: {half, float, double}"}},
      {
        {{"dc"}, "ZerosLike", {"c"}, {{"T", DT_BOOL}}, {"dz"}},
        {{"zeros"}, "ZerosLike", {"x"}, {{"T", "$T"}}, {"dz"}},
        {{"dx"}, "Select", {"c", "dz", "zeros"}, {{"T", "$T"}}},
        {{"dy"}, "Select", {"c", "zeros", "dz"}, {{"T", "$T"}}},
      });
  // clang-format on
  return Status::OK();
}
REGISTER_OP_GRADIENT("Select", SelectGrad);

// N-ary ops
// REGISTER_OP_GRADIENT("AddN", AddNGrad);

// Reduction ops
//
// TODO(zhifengc): This helper is pretty ugly. Do something better.
// TODO(zhifengc): This can be arranged as a function in the standard library.
Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
  // Shape manipulation nodes.

  // clang-format off
  std::vector<FDH::Node> nodes = {
   {{"x_shape"}, "Shape", {"x"}},
   {{"x_rank"}, "Rank", {"x"}},
   {{"i_shape"}, "Shape", {"i"}, {{"T", DT_INT32}}},
   FDH::Const("zero", 0),
   FDH::Const("one", 1),
   // stitch_idx0 = Range(0, x_rank, 1)
   {{"stitch_val1"}, "Fill", {"i_shape:output:0", "one:output:0"},
    {{"T", DT_INT32}}},
   {{"y_shape"}, "DynamicStitch",
    {"stitch_idx0:output:0", "i",
     "x_shape:output:0", "stitch_val1:output:0"},
    {{"N", 2}, {"T", DT_INT32}}},
   {{"tile_scaling"}, "Div", {"x_shape:output:0", "y_shape:merged:0"},
    {{"T", DT_INT32}}},
   {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
  };
  // clang-format on
  nodes.insert(nodes.end(), body.begin(), body.end());
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", "$T"}};
    }
  }
  // "Range" doesn't need any attr.
  nodes.push_back({{"stitch_idx0"},
                   "Range",
                   {"zero:output:0", "x_rank:output:0", "one:output:0"},
                   {}});
  *g = FDH::Create("_",
                   // Input defs
                   {"x:T", "i:int32", "dy:T"},
                   // Ret val defs
                   {"dx:T", "di:int32"},
                   // Attr defs
                   {{"T: {half, float, double}"}},
                   // Nodes
                   nodes,
                   // Return values
                   {{"dx", "dx:output:0"}, {"di", "di:y:0"}});
  return Status::OK();
}
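
// Illustrative walk-through of the shape nodes above (not generated
// output): for x of shape [2, 3, 4] reduced over i = [1], x_rank = 3 and
// stitch_idx0 = [0, 1, 2]. DynamicStitch merges x_shape with 1s at the
// reduced axes, giving y_shape = [2, 1, 4], and tile_scaling =
// x_shape / y_shape = [1, 3, 1]. The body then reshapes dy to y_shape and
// tiles it back to the shape of x.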

Status SumGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForReductionOp(g, {
    {{"dy_reshaped"}, "Reshape", {"dy", "y_shape:merged:0"}},
    {{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sum", SumGrad);

Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForReductionOp(g, {
    {{"factor"}, "Prod", {"tile_scaling:z:0", "zero:output:0"},
                   {{"T", DT_INT32}}},
    {{"factor_T"}, "Cast", {"factor:output:0"},
                   {{"SrcT", DT_INT32}, {"DstT", "$T"}}},
    {{"dy_scaled"}, "Div", {"dy", "factor_T:y:0"}},
    {{"dy_reshaped"}, "Reshape", {"dy_scaled:z:0", "y_shape:merged:0"}},
    {{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Mean", MeanGrad);
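
// In MeanGrad above, Mean is Sum divided by the number of reduced
// elements; factor = Prod(tile_scaling) recovers that count, so MeanGrad
// is SumGrad with dy pre-scaled by 1 / factor.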

// REGISTER_OP_GRADIENT("Prod", ProdGrad);
// REGISTER_OP_GRADIENT("SegmentSum", SegmentSumGrad);
// REGISTER_OP_GRADIENT("SegmentMean", SegmentMeanGrad);
// REGISTER_OP_GRADIENT("SparseSegmentSum", SparseSegmentSumGrad);
// REGISTER_OP_GRADIENT("SparseSegmentMean", SparseSegmentMeanGrad);
// REGISTER_OP_GRADIENT("SparseSegmentSqrtN", SparseSegmentSqrtNGrad);
// REGISTER_OP_GRADIENT("SegmentMin", SegmentMinGrad);
// REGISTER_OP_GRADIENT("SegmentMax", SegmentMaxGrad);
// REGISTER_OP_GRADIENT("UnsortedSegmentSum", UnsortedSegmentSumGrad);
// REGISTER_OP_GRADIENT("UnsortedSegmentMax", UnsortedSegmentMaxGrad);

Status MinMaxGradHelper(const string& op, const AttrSlice& attrs,
                        FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x:T", "i:int32", "dy:T"},
      // Ret val defs
      {"dx:T", "di:int32"},
      // Attr defs
      {{"T: {half, float, double}"}},
      {
        // keep_dims because we need to do x == y, which requires that x
        // and y be broadcastable.
        {{"y"}, op, {"x", "i"}, {{"T", "$T"}, {"keep_dims", true}}},
        {{"mask"}, "Equal", {"x", "y"}, {{"T", "$T"}}},
        {{"mask_cast"}, "Cast", {"mask"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
        {{"mask_sum"}, "Sum", {"mask_cast", "i"}, {{"T", "$T"}}},
        {{"norm_dy"}, "Div", {"dy", "mask_sum"}, {{"T", "$T"}}},
        {{"sy"}, "Shape", {"y"}, {{"T", "$T"}}},
        {{"norm_dy_reshaped"}, "Reshape", {"norm_dy", "sy"}, {{"T", "$T"}}},
        {{"dx"}, "Mul", {"mask_cast", "norm_dy_reshaped"}, {{"T", "$T"}}},
        {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
      });
  // clang-format on
  return Status::OK();
}
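
// In MinMaxGradHelper above, the reduced extreme y is recomputed with
// keep_dims=true so it broadcasts against x. mask marks every element
// attaining the extreme, and dy is divided by mask_sum so that when
// several elements tie, the gradient is split evenly among them.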

Status MaxGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MinMaxGradHelper("Max", attrs, g);
}
REGISTER_OP_GRADIENT("Max", MaxGrad);

Status MinGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MinMaxGradHelper("Min", attrs, g);
}
REGISTER_OP_GRADIENT("Min", MinGrad);

static Status MatMulGradHelper(FunctionDef* g, const string& opname,
                               const string& attr_adj_x,
                               const string& attr_adj_y, const string& x0,
                               bool ax0, const string& x1, bool ax1,
                               const string& y0, bool ay0, const string& y1,
                               bool ay1) {
  *g = FDH::Define(
      // Arg defs
      {"x: T", "y: T", "dz: T"},
      // Ret val defs
      {"dx: T", "dy: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      {
          {{"dx"},
           opname,
           {x0, x1},
           {{"T", "$T"}, {attr_adj_x, ax0}, {attr_adj_y, ax1}}},
          {{"dy"},
           opname,
           {y0, y1},
           {{"T", "$T"}, {attr_adj_x, ay0}, {attr_adj_y, ay1}}},
      });
  return Status::OK();
}

Status MatMulGradCommon(const string& opname, const string& attr_adj_x,
                        const string& attr_adj_y, const AttrSlice& attrs,
                        FunctionDef* g) {
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex is not supported yet.");
  }
  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_y, &tb));
  if (!ta && !tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
                            true, "x", true, "dz", false);
  }
  if (!ta && tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
                            false, "dz", true, "x", false);
  }
  if (ta && !tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", false, "dz",
                            true, "x", false, "dz", false);
  }
  CHECK(ta && tb);
  return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", true, "dz",
                          true, "dz", true, "x", true);
}
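
// Summary of the four cases above, writing t(.) for the transposed
// operand:
//   z = x * y       : dx = dz * t(y),    dy = t(x) * dz
//   z = x * t(y)    : dx = dz * y,       dy = t(dz) * x
//   z = t(x) * y    : dx = y * t(dz),    dy = x * dz
//   z = t(x) * t(y) : dx = t(y) * t(dz), dy = t(dz) * t(x)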

Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MatMulGradCommon("MatMul", "transpose_a", "transpose_b", attrs, g);
}
REGISTER_OP_GRADIENT("MatMul", MatMulGrad);

Status BatchMatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MatMulGradCommon("BatchMatMul", "adj_x", "adj_y", attrs, g);
}
REGISTER_OP_GRADIENT("BatchMatMul", BatchMatMulGrad);

// REGISTER_OP_GRADIENT("SparseMatMul", SparseMatMulGrad);

// Comparison ops.
REGISTER_OP_NO_GRADIENT("Less");
REGISTER_OP_NO_GRADIENT("LessEqual");
REGISTER_OP_NO_GRADIENT("Greater");
REGISTER_OP_NO_GRADIENT("GreaterEqual");
REGISTER_OP_NO_GRADIENT("Equal");
REGISTER_OP_NO_GRADIENT("NotEqual");

// Logical ops.
REGISTER_OP_NO_GRADIENT("LogicalAnd");
REGISTER_OP_NO_GRADIENT("LogicalOr");
REGISTER_OP_NO_GRADIENT("LogicalNot");

// Sequence generation ops.
REGISTER_OP_NO_GRADIENT("Range");
REGISTER_OP_NO_GRADIENT("LinSpace");

REGISTER_OP_NO_GRADIENT("Floor");
REGISTER_OP_NO_GRADIENT("FloorDiv");
REGISTER_OP_NO_GRADIENT("TruncateDiv");

}  // end namespace tensorflow