• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <vector>
17 
18 #include "tensorflow/core/framework/function.h"
19 #include "tensorflow/core/framework/types.pb.h"
20 #include "tensorflow/core/lib/core/errors.h"
21 #include "tensorflow/core/lib/gtl/array_slice.h"
22 
23 namespace tensorflow {
24 
25 typedef FunctionDefHelper FDH;
26 
// Cwise unary ops
GradForUnaryCwise(FunctionDef * g,std::vector<FDH::Node> nodes)28 Status GradForUnaryCwise(FunctionDef* g, std::vector<FDH::Node> nodes) {
29   for (auto& n : nodes) {
30     if (n.attr.empty()) {
31       n.attr = {{"T", "$T"}};
32     }
33   }
34   *g = FDH::Define(
35       // Arg defs
36       {"x: T", "dy: T"},
37       // Ret val defs
38       {"dx: T"},
39       // Attr defs
40       {{"T: {half, float, double}"}},
41       // Nodes
42       nodes);
43   return OkStatus();
44 }
45 
AbsGrad(const AttrSlice & attrs,FunctionDef * g)46 Status AbsGrad(const AttrSlice& attrs, FunctionDef* g) {
47   // clang-format off
48   return GradForUnaryCwise(g, {
49       {{"sign"}, "Sign", {"x"}, {}, {"dy"}},
50       {{"dx"}, "Mul", {"dy", "sign"}},
51   });
52   // clang-format on
53 }
54 REGISTER_OP_GRADIENT("Abs", AbsGrad);
55 
NegGrad(const AttrSlice & attrs,FunctionDef * g)56 Status NegGrad(const AttrSlice& attrs, FunctionDef* g) {
57   // clang-format off
58   return GradForUnaryCwise(g, {
59       {{"dx"}, "Neg", {"dy"}},
60   });
61   // clang-format on
62 }
63 REGISTER_OP_GRADIENT("Neg", NegGrad);
64 
InvGrad(const AttrSlice & attrs,FunctionDef * g)65 Status InvGrad(const AttrSlice& attrs, FunctionDef* g) {
66   // clang-format off
67   return GradForUnaryCwise(g, {
68       {{"y"}, "Reciprocal", {"x"}},
69       {{"y2"}, "Square", {"y"}, {}, {"dy"}},
70       {{"y2_neg"}, "Neg", {"y2"}},
71       {{"dx"}, "Mul", {"dy", "y2_neg"}}
72   });
73   // clang-format on
74 }
75 REGISTER_OP_GRADIENT("Inv", InvGrad);
76 REGISTER_OP_GRADIENT("Reciprocal", InvGrad);
77 
SquareGrad(const AttrSlice & attrs,FunctionDef * g)78 Status SquareGrad(const AttrSlice& attrs, FunctionDef* g) {
79   // clang-format off
80   return GradForUnaryCwise(g, {
81       FDH::Const("c", int64_t{2}),
82       {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
83       {{"x2"}, "Mul", {"x", "two"}, {}, {"dy"}},  // x * 2
84       {{"dx"}, "Mul", {"dy", "x2"}},              // dy * (x * 2)
85   });
86   // clang-format on
87 }
88 REGISTER_OP_GRADIENT("Square", SquareGrad);
89 
SqrtGrad(const AttrSlice & attrs,FunctionDef * g)90 Status SqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
91   // clang-format off
92   return GradForUnaryCwise(g, {
93       {{"y"}, "Sqrt", {"x"}},
94       {{"y_inv"}, "Reciprocal", {"y"}, {}, {"dy"}},
95       FDH::Const("const", 0.5f),
96       {{"half"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
97       {{"a"}, "Mul", {"half", "y_inv"}},  // .5 * 1/y
98       {{"dx"}, "Mul", {"dy", "a"}},  // dy * (.5 * 1/y)
99   });
100   // clang-format on
101 }
102 REGISTER_OP_GRADIENT("Sqrt", SqrtGrad);
103 
RsqrtGrad(const AttrSlice & attrs,FunctionDef * g)104 Status RsqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
105   // clang-format off
106   return GradForUnaryCwise(g, {
107       {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}},
108       {{"y"}, "Rsqrt", {"x"}},
109       FDH::Const("const", -.5f),
110       {{"neghalf"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
111       {{"a"}, "Mul", {"neghalf", "x_inv"}},   // -0.5 * 1/x
112       {{"b"}, "Mul", {"a", "y"}},             // -0.5 * 1/x * y
113       {{"dx"}, "Mul", {"dy", "b"}},           // dy * (1/y * .5)
114   });
115   // clang-format on
116 }
117 REGISTER_OP_GRADIENT("Rsqrt", RsqrtGrad);
118 
ExpGrad(const AttrSlice & attrs,FunctionDef * g)119 Status ExpGrad(const AttrSlice& attrs, FunctionDef* g) {
120   // clang-format off
121   return GradForUnaryCwise(g, {
122       {{"y"}, "Exp", {"x"}},
123       {{"dx"}, "Mul", {"dy", "y"}},           // dy * y
124   });
125   // clang-format on
126 }
127 REGISTER_OP_GRADIENT("Exp", ExpGrad);
128 
Expm1Grad(const AttrSlice & attrs,FunctionDef * g)129 Status Expm1Grad(const AttrSlice& attrs, FunctionDef* g) {
130   // clang-format off
131   return GradForUnaryCwise(g, {
132       {{"y"}, "Exp", {"x"}},
133       {{"dx"}, "Mul", {"dy", "y"}},           // dy * y
134   });
135   // clang-format on
136 }
137 REGISTER_OP_GRADIENT("Expm1", Expm1Grad);
138 
LogGrad(const AttrSlice & attrs,FunctionDef * g)139 Status LogGrad(const AttrSlice& attrs, FunctionDef* g) {
140   // clang-format off
141   return GradForUnaryCwise(g, {
142       {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}},
143       {{"dx"}, "Mul", {"dy", "x_inv"}},           // dy * 1/x
144   });
145   // clang-format on
146 }
147 REGISTER_OP_GRADIENT("Log", LogGrad);
148 
Log1pGrad(const AttrSlice & attrs,FunctionDef * g)149 Status Log1pGrad(const AttrSlice& attrs, FunctionDef* g) {
150   // clang-format off
151   return GradForUnaryCwise(g, {
152       FDH::Const("const", 1.0f),
153       {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
154       {{"a"}, "Add", {"one", "x"}},
155       {{"dx"}, "Div", {"dy", "a"}},           // dy / (1 + x)
156   });
157   // clang-format on
158 }
159 REGISTER_OP_GRADIENT("Log1p", Log1pGrad);
160 
SinhGrad(const AttrSlice & attrs,FunctionDef * g)161 Status SinhGrad(const AttrSlice& attrs, FunctionDef* g) {
162   // clang-format off
163   return GradForUnaryCwise(g, {
164       {{"cosh"}, "Cosh", {"x"}, {}, {"dy"}},
165       {{"dx"}, "Mul", {"dy", "cosh"}},  // dy * cosh(x)
166   });
167   // clang-format on
168 }
169 REGISTER_OP_GRADIENT("Sinh", SinhGrad);
170 
CoshGrad(const AttrSlice & attrs,FunctionDef * g)171 Status CoshGrad(const AttrSlice& attrs, FunctionDef* g) {
172   // clang-format off
173   return GradForUnaryCwise(g, {
174       {{"sinh"}, "Sinh", {"x"}, {}, {"dy"}},
175       {{"dx"}, "Mul", {"dy", "sinh"}},  // dy * sinh(x)
176   });
177   // clang-format on
178 }
179 REGISTER_OP_GRADIENT("Cosh", CoshGrad);
180 
TanhGrad(const AttrSlice & attrs,FunctionDef * g)181 Status TanhGrad(const AttrSlice& attrs, FunctionDef* g) {
182   // clang-format off
183   return GradForUnaryCwise(g, {
184       {{"y"}, "Tanh", {"x"}},
185       {{"y2"}, "Square", {"y"}, {}, {"dy"}},
186       FDH::Const("const", 1.0f),
187       {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
188       {{"a"}, "Sub", {"one", "y2"}},
189       {{"dx"}, "Mul", {"dy", "a"}},           // dy * (1 - y*y)
190   });
191   // clang-format on
192 }
193 REGISTER_OP_GRADIENT("Tanh", TanhGrad);
194 
AsinhGrad(const AttrSlice & attrs,FunctionDef * g)195 Status AsinhGrad(const AttrSlice& attrs, FunctionDef* g) {
196   // clang-format off
197   return GradForUnaryCwise(g, {
198       {{"y"}, "Asinh", {"x"}},
199       {{"cosh"}, "Cosh", {"y"}},
200       {{"dx"}, "Mul", {"dy", "cosh"}},  // dy * cosh(y)
201   });
202   // clang-format on
203 }
204 REGISTER_OP_GRADIENT("Asinh", AsinhGrad);
205 
AcoshGrad(const AttrSlice & attrs,FunctionDef * g)206 Status AcoshGrad(const AttrSlice& attrs, FunctionDef* g) {
207   // clang-format off
208   return GradForUnaryCwise(g, {
209       {{"y"}, "Acosh", {"x"}},
210       {{"sinh"}, "Sinh", {"y"}},
211       {{"dx"}, "Mul", {"dy", "sinh"}},  // dy * sinh(y)
212   });
213   // clang-format on
214 }
215 REGISTER_OP_GRADIENT("Acosh", AcoshGrad);
216 
AtanhGrad(const AttrSlice & attrs,FunctionDef * g)217 Status AtanhGrad(const AttrSlice& attrs, FunctionDef* g) {
218   // clang-format off
219   return GradForUnaryCwise(g, {
220     {{"x2"}, "Square", {"x"}},
221     FDH::Const("const", 1.0f),
222     {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
223     {{"a"}, "Sub", {"one", "x2"}}, // 1 - x^2
224     {{"inv"}, "Reciprocal", {"a"}},
225     {{"dx"}, "Mul", {"dy", "inv"}}
226   });
227   // clang-format on
228 }
229 REGISTER_OP_GRADIENT("Atanh", AtanhGrad);
230 
SigmoidGrad(const AttrSlice & attrs,FunctionDef * g)231 Status SigmoidGrad(const AttrSlice& attrs, FunctionDef* g) {
232   // clang-format off
233   return GradForUnaryCwise(g, {
234       {{"y"}, "Sigmoid", {"x"}},
235       FDH::Const("const", 1.0f),
236       {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
237       {{"a"}, "Sub", {"one", "y"}, {}, {"dy"}},
238       {{"b"}, "Mul", {"y", "a"}},             // y * (1 - y)
239       {{"dx"}, "Mul", {"dy", "b"}},           // dy * y * (1 - y)
240   });
241   // clang-format on
242 }
243 REGISTER_OP_GRADIENT("Sigmoid", SigmoidGrad);
244 
SignGrad(const AttrSlice & attrs,FunctionDef * g)245 Status SignGrad(const AttrSlice& attrs, FunctionDef* g) {
246   // clang-format off
247   return GradForUnaryCwise(g, {
248       {{"s"}, "Shape", {"x"}},
249       FDH::Const("zero", 0.f),
250       {{"val"}, "Cast", {"zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
251       {{"dx"}, "Fill", {"s", "val"}},
252   });
253   // clang-format on
254 }
255 REGISTER_OP_GRADIENT("Sign", SignGrad);
256 
SinGrad(const AttrSlice & attrs,FunctionDef * g)257 Status SinGrad(const AttrSlice& attrs, FunctionDef* g) {
258   // clang-format off
259   return GradForUnaryCwise(g, {
260       {{"cos"}, "Cos", {"x"}, {}, {"dy"}},
261       {{"dx"}, "Mul", {"dy", "cos"}},  // dy * cos(x)
262   });
263   // clang-format on
264 }
265 REGISTER_OP_GRADIENT("Sin", SinGrad);
266 
CosGrad(const AttrSlice & attrs,FunctionDef * g)267 Status CosGrad(const AttrSlice& attrs, FunctionDef* g) {
268   // clang-format off
269   return GradForUnaryCwise(g, {
270       {{"sin"}, "Sin", {"x"}, {}, {"dy"}},
271       {{"neg"}, "Neg", {"sin"}},
272       {{"dx"}, "Mul", {"dy", "neg"}},  // dy * (-sin(x))
273   });
274   // clang-format on
275 }
276 REGISTER_OP_GRADIENT("Cos", CosGrad);
277 
AcosGrad(const AttrSlice & attrs,FunctionDef * g)278 Status AcosGrad(const AttrSlice& attrs, FunctionDef* g) {
279   // clang-format off
280   return GradForUnaryCwise(g, {
281     {{"x2"}, "Square", {"x"}},
282     FDH::Const("const", 1.0f),
283     {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
284     {{"a"}, "Sub", {"one", "x2"}}, // 1 - x^2
285     {{"b"}, "Sqrt", {"a"}},
286     {{"inv"}, "Reciprocal", {"b"}},
287     {{"neg"}, "Neg", {"inv"}},
288     {{"dx"}, "Mul", {"dy", "neg"}}
289   });
290   // clang-format on
291 }
292 REGISTER_OP_GRADIENT("Acos", AcosGrad);
293 
AsinGrad(const AttrSlice & attrs,FunctionDef * g)294 Status AsinGrad(const AttrSlice& attrs, FunctionDef* g) {
295   // clang-format off
296   return GradForUnaryCwise(g, {
297     {{"x2"}, "Square", {"x"}},
298     FDH::Const("const", 1.0f),
299     {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
300     {{"a"}, "Sub", {"one", "x2"}}, // 1 - x^2
301     {{"b"}, "Sqrt", {"a"}},
302     {{"inv"}, "Reciprocal", {"b"}},
303     {{"dx"}, "Mul", {"dy", "inv"}}
304   });
305   // clang-format on
306 }
307 REGISTER_OP_GRADIENT("Asin", AsinGrad);
308 
AtanGrad(const AttrSlice & attrs,FunctionDef * g)309 Status AtanGrad(const AttrSlice& attrs, FunctionDef* g) {
310   // clang-format off
311   return GradForUnaryCwise(g, {
312     {{"x2"}, "Square", {"x"}},
313     FDH::Const("const", 1.0f),
314     {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
315     {{"a"}, "Add", {"one", "x2"}}, // 1 + x^2
316     {{"inv"}, "Reciprocal", {"a"}},
317     {{"dx"}, "Mul", {"dy", "inv"}}
318   });
319   // clang-format on
320 }
321 REGISTER_OP_GRADIENT("Atan", AtanGrad);
322 
TanGrad(const AttrSlice & attrs,FunctionDef * g)323 Status TanGrad(const AttrSlice& attrs, FunctionDef* g) {
324   // clang-format off
325   return GradForUnaryCwise(g, {
326     {{"cosx"}, "Cos", {"x"}},
327     {{"secx"}, "Reciprocal", {"cosx"}},
328     {{"secx2"}, "Square", {"secx"}},
329     {{"dx"}, "Mul", {"dy", "secx2"}}
330   });
331   // clang-format on
332 }
333 REGISTER_OP_GRADIENT("Tan", TanGrad);
334 
RealGrad(const AttrSlice & attrs,FunctionDef * g)335 Status RealGrad(const AttrSlice& attrs, FunctionDef* g) {
336   // clang-format off
337   return GradForUnaryCwise(g, {
338       FDH::Const("zero", 0.f),
339       {{"dx"}, "Complex", {"dy", "zero"}},
340   });
341   // clang-format on
342 }
343 REGISTER_OP_GRADIENT("Real", RealGrad);
344 
ImagGrad(const AttrSlice & attrs,FunctionDef * g)345 Status ImagGrad(const AttrSlice& attrs, FunctionDef* g) {
346   // clang-format off
347   return GradForUnaryCwise(g, {
348       FDH::Const("zero", 0.f),
349       {{"dx"}, "Complex", {"zero", "dy"}},
350   });
351   // clang-format on
352 }
353 REGISTER_OP_GRADIENT("Imag", ImagGrad);
354 
AngleGrad(const AttrSlice & attrs,FunctionDef * g)355 Status AngleGrad(const AttrSlice& attrs, FunctionDef* g) {
356   // clang-format off
357   return GradForUnaryCwise(g, {
358       {{"re"}, "Real", {"x"}},
359       {{"im"}, "Imag", {"x"}},
360       {{"z"}, "Complex", {"im", "re"}},
361       {{"z_inv"}, "Reciprocal", {"z"}},
362       {{"neg"}, "Neg", {"z_inv"}},
363       {{"dx"}, "Mul", {"neg", "dy"}},
364   });
365   // clang-format on
366 }
367 REGISTER_OP_GRADIENT("Angle", AngleGrad);
368 
ConjGrad(const AttrSlice & attrs,FunctionDef * g)369 Status ConjGrad(const AttrSlice& attrs, FunctionDef* g) {
370   // clang-format off
371   return GradForUnaryCwise(g, {
372       {{"dx"}, "Conj", {"dy"}},
373   });
374   // clang-format on
375 }
376 REGISTER_OP_GRADIENT("Conj", ConjGrad);
377 
CastGrad(const AttrSlice & attrs,FunctionDef * g)378 Status CastGrad(const AttrSlice& attrs, FunctionDef* g) {
379   // clang-format off
380   *g = FDH::Define(
381       // Arg defs
382       {"x: SrcT", "dy: DstT"},
383       // Ret val defs
384       {"dx: SrcT"},
385       // Attr defs
386       {{"SrcT: type"}, {"DstT: type"}},
387       // Nodes
388       {{{"dx"}, "Cast", {"dy"}, {{"SrcT", "$DstT"}, {"DstT", "$SrcT"}}}});
389   return OkStatus();
390   // clang-format on
391 }
392 REGISTER_OP_GRADIENT("Cast", CastGrad);
393 
// Cwise binary ops
//
// TODO(zhifengc): This can be arranged as a function in the standard
// library.
GradForBinaryCwise(FunctionDef * g,std::vector<FDH::Node> body)398 Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) {
399   // clang-format off
400   std::vector<FDH::Node> nodes = {
401     {{"sx"}, "Shape", {"x"}},
402     {{"sy"}, "Shape", {"y"}},
403   };
404   nodes.insert(nodes.end(), body.begin(), body.end());
405   std::vector<FDH::Node> reshapes = {
406     {{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}},
407     {{"sum_gx"}, "Sum", {"gx", "rx"}},
408     {{"dx"}, "Reshape", {"sum_gx", "sx"}},
409     {{"sum_gy"}, "Sum", {"gy", "ry"}},
410     {{"dy"}, "Reshape", {"sum_gy", "sy"}},
411   };
412   nodes.insert(nodes.end(), reshapes.begin(), reshapes.end());
413 
414   // clang-format on
415   for (auto& n : nodes) {
416     // "BroadcastGradientArgs" doesn't need any attrs.
417     if (n.attr.empty() && n.op != "BroadcastGradientArgs") {
418       n.attr = {{"T", "$T"}};
419     }
420   }
421   *g = FDH::Define(
422       // Arg defs
423       {"x: T", "y: T", "dz: T"},
424       // Ret val defs
425       {"dx: T", "dy: T"},
426       // Attr defs
427       {{"T: {half, float, double}"}},
428       // Nodes
429       nodes);
430   return OkStatus();
431 }
432 
AddGrad(const AttrSlice & attrs,FunctionDef * g)433 Status AddGrad(const AttrSlice& attrs, FunctionDef* g) {
434   // clang-format off
435   return GradForBinaryCwise(g, {
436       {{"gx"}, "Identity", {"dz"}},
437       {{"gy"}, "Identity", {"dz"}},
438   });
439   // clang-format on
440 }
441 REGISTER_OP_GRADIENT("Add", AddGrad);
442 REGISTER_OP_GRADIENT("AddV2", AddGrad);
443 
SubGrad(const AttrSlice & attrs,FunctionDef * g)444 Status SubGrad(const AttrSlice& attrs, FunctionDef* g) {
445   // clang-format off
446   return GradForBinaryCwise(g, {
447       {{"gx"}, "Identity", {"dz"}},
448       {{"gy"}, "Neg", {"dz"}},          // -dz
449   });
450   // clang-format on
451 }
452 REGISTER_OP_GRADIENT("Sub", SubGrad);
453 
MulGrad(const AttrSlice & attrs,FunctionDef * g)454 Status MulGrad(const AttrSlice& attrs, FunctionDef* g) {
455   DataType T;
456   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
457   if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
458     return GradForBinaryCwise(
459         g, {
460                {{"cy"}, "Conj", {"y"}, {}, {"dz"}},
461                {{"gx"}, "Mul", {"dz", "cy"}},  // dz * Conj(y)
462                {{"cx"}, "Conj", {"x"}, {}, {"dz"}},
463                {{"gy"}, "Mul", {"cx", "dz"}},  // Conj(x) * dz
464            });
465   } else {
466     // clang-format off
467     return GradForBinaryCwise(g, {
468         {{"gx"}, "Mul", {"dz", "y"}},  // dz * y
469         {{"gy"}, "Mul", {"x", "dz"}},  // x * dz
470     });
471     // clang-format on
472   }
473 }
474 REGISTER_OP_GRADIENT("Mul", MulGrad);
475 
MulNoNanGrad(const AttrSlice & attrs,FunctionDef * g)476 Status MulNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
477   // clang-format off
478   return GradForBinaryCwise(g, {
479       {{"gx"}, "MulNoNan", {"y", "dz"}},  // y * dz
480       {{"gy"}, "MulNoNan", {"x", "dz"}},  // x * dz
481   });
482   // clang-format on
483 }
484 REGISTER_OP_GRADIENT("MulNoNan", MulGrad);
485 
DivGrad(const AttrSlice & attrs,FunctionDef * g)486 Status DivGrad(const AttrSlice& attrs, FunctionDef* g) {
487   // clang-format off
488   return GradForBinaryCwise(g, {
489       {{"gx"}, "Div", {"dz", "y"}},
490       {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
491       {{"y2"}, "Square", {"y"}, {}, {"dz"}},
492       {{"nx_y2"}, "Div", {"nx", "y2"}},
493       {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
494   });
495   // clang-format on
496 }
497 REGISTER_OP_GRADIENT("Div", DivGrad);
498 
RealDivGrad(const AttrSlice & attrs,FunctionDef * g)499 Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) {
500   // clang-format off
501   return GradForBinaryCwise(g, {
502       {{"gx"}, "RealDiv", {"dz", "y"}},
503       {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
504       {{"y2"}, "Square", {"y"}, {}, {"dz"}},
505       {{"nx_y2"}, "RealDiv", {"nx", "y2"}},
506       {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
507   });
508   // clang-format on
509 }
510 REGISTER_OP_GRADIENT("RealDiv", RealDivGrad);
511 
DivNoNanGrad(const AttrSlice & attrs,FunctionDef * g)512 Status DivNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
513   // clang-format off
514   return GradForBinaryCwise(g, {
515       {{"gx"}, "DivNoNan", {"dz", "y"}},
516       {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
517       {{"y2"}, "Square", {"y"}, {}, {"dz"}},
518       {{"nx_y2"}, "DivNoNan", {"nx", "y2"}},
519       {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
520   });
521   // clang-format on
522 }
523 REGISTER_OP_GRADIENT("DivNoNan", DivNoNanGrad);
524 
PowGrad(const AttrSlice & attrs,FunctionDef * g)525 Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
526   // clang-format off
527   std::vector<FDH::Node> nodes = {
528     {{"z"}, "Pow", {"x", "y"}},
529     // dz * y * Pow(x, y - 1)
530     FDH::Const("const_zero", 0.0f),
531     FDH::Const("const_one", 1.0f),
532     {{"zero"}, "Cast", {"const_zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
533     {{"one"}, "Cast", {"const_one"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
534     {{"t0"}, "Sub", {"y", "one"}, {}, {"dz"}},
535     {{"t1"}, "Pow", {"x", "t0"}},
536     {{"t2"}, "Mul", {"dz", "y"}},
537     {{"gx"}, "Mul", {"t1", "t2"}},
538     {{"unsafe_log"}, "Log", {"x"}, {}, {"dz"}},
539     {{"zeros"}, "ZerosLike", {"x"}}};
540   // clang-format on
541   std::vector<FDH::Node> log_x_handling;
542   DataType T;
543   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
544   if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
545     // dz * z * (x != 0 ? Log(x) : 0)
546     // clang-format off
547     log_x_handling = {
548       {{"nz_x"}, "NotEqual", {"x", "zero"}},
549       {{"safe_log"}, "Select", {"nz_x", "unsafe_log", "zeros"}}};
550     // clang-format on
551   } else {
552     // dz * z * (x > 0 ? Log(x) : 0)
553     // clang-format off
554     log_x_handling = {
555       {{"pos_x"}, "Greater", {"x", "zero"}},
556       {{"safe_log"}, "Select", {"pos_x", "unsafe_log", "zeros"}}};
557     // clang-format on
558   }
559   nodes.insert(nodes.end(), log_x_handling.begin(), log_x_handling.end());
560   nodes.push_back({{"t4"}, "Mul", {"dz", "z"}});
561   nodes.push_back({{"gy"}, "Mul", {"safe_log", "t4"}});
562   return GradForBinaryCwise(g, nodes);
563 }
564 REGISTER_OP_GRADIENT("Pow", PowGrad);
565 
XlogyGrad(const AttrSlice & attrs,FunctionDef * g)566 Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) {
567   // clang-format off
568   return GradForBinaryCwise(g, {
569       {{"zeros"}, "ZerosLike", {"x"}},
570       {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
571       {{"is_zero_cast"}, "Cast", {"is_x_zero"},
572         {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
573       {{"safe_logy"}, "Xlogy", {"is_zero_cast", "y"}},
574       {{"xlogygrad"}, "Xdivy", {"x", "y"}},
575       {{"gx"}, "Mul", {"safe_logy", "dz"}},
576       {{"gy"}, "Mul", {"xlogygrad", "dz"}},
577   });
578   // clang-format on
579 }
580 REGISTER_OP_GRADIENT("Xlogy", XlogyGrad);
581 
Xlog1pyGrad(const AttrSlice & attrs,FunctionDef * g)582 Status Xlog1pyGrad(const AttrSlice& attrs, FunctionDef* g) {
583   // clang-format off
584   return GradForBinaryCwise(g, {
585       FDH::Const("const", 1.0f),
586       {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
587       {{"zeros"}, "ZerosLike", {"x"}},
588       {{"yp1"}, "Add", {"y", "one"}},
589       {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
590       {{"is_zero_cast"}, "Cast", {"is_x_zero"},
591         {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
592       {{"safe_log1py"}, "Xlog1py", {"is_zero_cast", "y"}},
593       {{"xlog1pygrad"}, "Xdivy", {"x", "yp1"}},
594       {{"gx"}, "Mul", {"safe_log1py", "dz"}},
595       {{"gy"}, "Mul", {"xlog1pygrad", "dz"}},
596   });
597   // clang-format on
598 }
599 REGISTER_OP_GRADIENT("Xlog1py", Xlog1pyGrad);
600 
XdivyGrad(const AttrSlice & attrs,FunctionDef * g)601 Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) {
602   // clang-format off
603   return GradForBinaryCwise(g, {
604       {{"zeros"}, "ZerosLike", {"x"}},
605       {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
606       {{"is_zero_cast"}, "Cast", {"is_x_zero"},
607         {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
608       {{"safe_divy"}, "Xdivy", {"is_zero_cast", "y"}},
609       {{"y2"}, "Square", {"y"}},
610       {{"negy2"}, "Neg", {"y2"}},
611       {{"xdivygrad"}, "Xdivy", {"x", "negy2"}},
612       {{"gx"}, "Mul", {"safe_divy", "dz"}},
613       {{"gy"}, "Mul", {"xdivygrad", "dz"}},
614   });
615   // clang-format on
616 }
617 REGISTER_OP_GRADIENT("Xdivy", XdivyGrad);
618 
SquaredDifferenceGrad(const AttrSlice & attrs,FunctionDef * g)619 Status SquaredDifferenceGrad(const AttrSlice& attrs, FunctionDef* g) {
620   // clang-format off
621   return GradForBinaryCwise(g, {
622       FDH::Const("c", int64_t{2}),
623       {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
624       {{"x_sub_y"}, "Sub", {"x", "y"}},
625       {{"two_x_sub_y"}, "Mul", {"two", "x_sub_y"}},  // 2 * (x - y)
626       {{"gx"}, "Mul", {"two_x_sub_y", "dz"}},
627       {{"gy"}, "Neg", {"gx"}}
628     });
629   // clang-format on
630 }
631 REGISTER_OP_GRADIENT("SquaredDifference", SquaredDifferenceGrad);
632 
MaximumMinimumGradHelper(const string & comparator,const AttrSlice & attrs,FunctionDef * g)633 Status MaximumMinimumGradHelper(const string& comparator,
634                                 const AttrSlice& attrs, FunctionDef* g) {
635   // clang-format off
636   return GradForBinaryCwise(g, {
637       {{"c"}, comparator, {"x", "y"}, {}, {"dz"}},
638       {{"mask"}, "Cast", {"c"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
639       {{"gx"}, "Mul", {"dz", "mask"}},
640       {{"gy"}, "Sub", {"dz", "gx"}},
641   });
642   // clang-format on
643 }
644 
MaximumGrad(const AttrSlice & attrs,FunctionDef * g)645 Status MaximumGrad(const AttrSlice& attrs, FunctionDef* g) {
646   return MaximumMinimumGradHelper("GreaterEqual", attrs, g);
647 }
648 REGISTER_OP_GRADIENT("Maximum", MaximumGrad);
649 
MinimumGrad(const AttrSlice & attrs,FunctionDef * g)650 Status MinimumGrad(const AttrSlice& attrs, FunctionDef* g) {
651   return MaximumMinimumGradHelper("LessEqual", attrs, g);
652 }
653 REGISTER_OP_GRADIENT("Minimum", MinimumGrad);
654 
ComplexGrad(const AttrSlice & attrs,FunctionDef * g)655 Status ComplexGrad(const AttrSlice& attrs, FunctionDef* g) {
656   // clang-format off
657   return GradForBinaryCwise(g, {
658       {{"gx"}, "Real", {"dz"}},
659       {{"gy"}, "Imag", {"dz"}},
660   });
661   // clang-format on
662 }
663 REGISTER_OP_GRADIENT("Complex", ComplexGrad);
664 
665 // Cwise ternary ops.
SelectGrad(const AttrSlice & attrs,FunctionDef * g)666 Status SelectGrad(const AttrSlice& attrs, FunctionDef* g) {
667   // clang-format off
668   *g = FDH::Define(
669       {"c:bool", "x:T", "y:T", "dz:T"},
670       {"dc:bool", "dx:T", "dy:T"},
671       {{"T: {half, float, double}"}},
672       {
673         {{"dc"}, "ZerosLike", {"c"}, {{"T", DT_BOOL}}, {"dz"}},
674         {{"zeros"}, "ZerosLike", {"x"}, {{"T", "$T"}}, {"dz"}},
675         {{"dx"}, "Select", {"c", "dz", "zeros"}, {{"T", "$T"}}},
676         {{"dy"}, "Select", {"c", "zeros", "dz"}, {{"T", "$T"}}},
677       });
678   // clang-format on
679   return OkStatus();
680 }
681 REGISTER_OP_GRADIENT("Select", SelectGrad);
682 
683 // N-ry ops
684 // REGISTER_OP_GRADIENT("AddN", AddNGrad);
685 
// Reduction ops
//
// TODO(zhifengc): This helper is pretty ugly. Do something better.
// TODO(zhifengc): This can be arranged as a function in the standard library.
GradForReductionOp(FunctionDef * g,std::vector<FDH::Node> body)690 Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
691   // Shape manipulation nodes.
692 
693   // clang-format off
694   std::vector<FDH::Node> nodes = {
695    {{"x_shape"}, "Shape", {"x"}},
696    {{"x_rank"}, "Rank", {"x"}},
697    {{"i_shape"}, "Shape", {"i"}, {{"T", DT_INT32}}},
698    FDH::Const("zero", 0),
699    FDH::Const("one", 1),
700    // stitch_idx0 = Range(0, x_rank, 1)
701    {{"stitch_val1"}, "Fill", {"i_shape:output:0", "one:output:0"},
702     {{"T", DT_INT32}}},
703    {{"y_shape"}, "DynamicStitch",
704     {"stitch_idx0:output:0", "i",
705      "x_shape:output:0", "stitch_val1:output:0"},
706     {{"N", 2}, {"T", DT_INT32}}},
707    {{"tile_scaling"}, "Div", {"x_shape:output:0", "y_shape:merged:0"},
708     {{"T", DT_INT32}}},
709    {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
710   };
711   // clang-format on
712   nodes.insert(nodes.end(), body.begin(), body.end());
713   for (auto& n : nodes) {
714     if (n.attr.empty()) {
715       n.attr = {{"T", "$T"}};
716     }
717   }
718   // "Range" doesn't need any attr.
719   nodes.push_back({{"stitch_idx0"},
720                    "Range",
721                    {"zero:output:0", "x_rank:output:0", "one:output:0"},
722                    {}});
723   *g = FDH::Create("_",
724                    // Input defs
725                    {"x:T", "i:int32", "dy:T"},
726                    // Ret val defs
727                    {"dx:T", "di:int32"},
728                    // Attr defs
729                    {{"T: {half, float, double}"}},
730                    // Nodes
731                    nodes,
732                    // Return values
733                    {{"dx", "dx:output:0"}, {"di", "di:y:0"}});
734   return OkStatus();
735 }
736 
SumGrad(const AttrSlice & attrs,FunctionDef * g)737 Status SumGrad(const AttrSlice& attrs, FunctionDef* g) {
738   // clang-format off
739   return GradForReductionOp(g, {
740     {{"dy_reshaped"}, "Reshape", {"dy", "y_shape:merged:0"}},
741     {{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
742   });
743   // clang-format on
744 }
745 REGISTER_OP_GRADIENT("Sum", SumGrad);
746 
MeanGrad(const AttrSlice & attrs,FunctionDef * g)747 Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) {
748   // clang-format off
749   return GradForReductionOp(g, {
750     {{"factor"}, "Prod", {"tile_scaling:z:0", "zero:output:0"},
751                    {{"T", DT_INT32}}},
752     {{"factor_T"}, "Cast", {"factor:output:0"},
753                    {{"SrcT", DT_INT32}, {"DstT", "$T"}}},
754     {{"dy_scaled"}, "Div", {"dy", "factor_T:y:0"}},
755     {{"dy_reshaped"}, "Reshape", {"dy_scaled:z:0", "y_shape:merged:0"}},
756     {{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
757   });
758   // clang-format on
759 }
760 REGISTER_OP_GRADIENT("Mean", MeanGrad);
761 
762 // REGISTER_OP_GRADIENT("Prod", ProdGrad);
763 // REGISTER_OP_GRADIENT("SegmentSum", SegmentSumGrad);
764 // REGISTER_OP_GRADIENT("SegmentMean", SegmentMeanGrad);
765 // REGISTER_OP_GRADIENT("SparseSegmentSum", SparseSegmentSumGrad);
766 // REGISTER_OP_GRADIENT("SparseSegmentMean", SparseSegmentMeanGrad);
767 // REGISTER_OP_GRADIENT("SparseSegmentSqrtN", SparseSegmentSqrtNGrad);
768 // REGISTER_OP_GRADIENT("SegmentMin", SegmentMinGrad);
769 // REGISTER_OP_GRADIENT("SegmentMax", SegmentMaxGrad);
770 // REGISTER_OP_GRADIENT("UnsortedSegmentSum", UnsortedSegmentSumGrad);
771 // REGISTER_OP_GRADIENT("UnsortedSegmentMax", UnsortedSegmentMaxGrad);
772 
MinMaxGradHelper(const string & op,const AttrSlice & attrs,FunctionDef * g)773 Status MinMaxGradHelper(const string& op, const AttrSlice& attrs,
774                         FunctionDef* g) {
775   // clang-format off
776   *g = FDH::Define(
777       // Arg defs
778       {"x:T", "i:int32", "dy:T"},
779       // Ret val defs
780       {"dx:T", "di:int32"},
781       // Attr defs
782       {{"T: {half, float, double}"}},
783       {
784         // keep_dims because we need to do x == y, which requires x
785         // and y are broadcastable.
786         {{"y"}, op, {"x", "i"}, {{"T", "$T"}, {"keep_dims", true}}},
787         {{"mask"}, "Equal", {"x", "y"}, {{"T", "$T"}}},
788         {{"mask_cast"}, "Cast", {"mask"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
789         {{"mask_sum"}, "Sum", {"mask_cast", "i"}, {{"T", "$T"}}},
790         {{"norm_dy"}, "Div", {"dy", "mask_sum"}, {{"T", "$T"}}},
791         {{"sy"}, "Shape", {"y"}, {{"T", "$T"}}},
792         {{"norm_dy_reshaped"}, "Reshape", {"norm_dy", "sy"}, {{"T", "$T"}}},
793         {{"dx"}, "Mul", {"mask_cast", "norm_dy_reshaped"}, {{"T", "$T"}}},
794         {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
795       });
796   // clang-format on
797   return OkStatus();
798 }
799 
MaxGrad(const AttrSlice & attrs,FunctionDef * g)800 Status MaxGrad(const AttrSlice& attrs, FunctionDef* g) {
801   return MinMaxGradHelper("Max", attrs, g);
802 }
803 REGISTER_OP_GRADIENT("Max", MaxGrad);
804 
MinGrad(const AttrSlice & attrs,FunctionDef * g)805 Status MinGrad(const AttrSlice& attrs, FunctionDef* g) {
806   return MinMaxGradHelper("Min", attrs, g);
807 }
808 REGISTER_OP_GRADIENT("Min", MinGrad);
809 
MatMulGradHelper(FunctionDef * g,const string & opname,const string & attr_adj_x,const string & attr_adj_y,const string & x0,bool ax0,const string & x1,bool ax1,const string & y0,bool ay0,const string & y1,bool ay1,bool enable_broadcasting)810 static Status MatMulGradHelper(FunctionDef* g, const string& opname,
811                                const string& attr_adj_x,
812                                const string& attr_adj_y, const string& x0,
813                                bool ax0, const string& x1, bool ax1,
814                                const string& y0, bool ay0, const string& y1,
815                                bool ay1, bool enable_broadcasting) {
816   // The final outputs are "dx" and "dy". If we're broadcasting compute
817   // intermediate nodes for now.
818   std::vector<FDH::Node> nodes = {
819       {{(enable_broadcasting ? "gx" : "dx")},
820        opname,
821        {x0, x1},
822        {{"T", "$T"}, {attr_adj_x, ax0}, {attr_adj_y, ax1}}},
823       {{(enable_broadcasting ? "gy" : "dy")},
824        opname,
825        {y0, y1},
826        {{"T", "$T"}, {attr_adj_x, ay0}, {attr_adj_y, ay1}}},
827   };
828   // TODO(anudhyan): Figure out a way to inspect the static shapes of "x" and
829   // "y". If they have the same batch dimensions, then we can omit adding the
830   // broadcasting-specific ops.
831   if (enable_broadcasting) {
832     std::vector<FDH::Node> unbroadcast_gradients = {
833         FDH::Const<int32>("zero", gtl::ArraySlice<int32>{0}),
834         FDH::Const<int32>("one", gtl::ArraySlice<int32>{1}),
835         FDH::Const<int32>("minustwo", gtl::ArraySlice<int32>{-2}),
836         // Compute the batch shapes of the inputs (all but last two dims).
837         {{"sx"}, "Shape", {"x"}, {{"T", "$T"}}},
838         {{"sy"}, "Shape", {"y"}, {{"T", "$T"}}},
839         {{"batch_sx"},
840          "StridedSlice",
841          {"sx", "zero", "minustwo", "one"},
842          {{"T", DT_INT32}, {"Index", DT_INT32}}},
843         {{"batch_sy"},
844          "StridedSlice",
845          {"sy", "zero", "minustwo", "one"},
846          {{"T", DT_INT32}, {"Index", DT_INT32}}},
847         // Sum along dimensions that the inputs were broadcasted across.
848         {{"rx", "ry"}, "BroadcastGradientArgs", {"batch_sx", "batch_sy"}},
849         {{"sum_gx"}, "Sum", {"gx", "rx"}, {{"T", "$T"}}},
850         {{"sum_gy"}, "Sum", {"gy", "ry"}, {{"T", "$T"}}},
851         {{"dx"}, "Reshape", {"sum_gx", "sx"}, {{"T", "$T"}}},
852         {{"dy"}, "Reshape", {"sum_gy", "sy"}, {{"T", "$T"}}}};
853     nodes.insert(nodes.end(), unbroadcast_gradients.begin(),
854                  unbroadcast_gradients.end());
855   }
856   *g = FDH::Define(
857       // Arg defs
858       {"x: T", "y: T", "dz: T"},
859       // Ret val defs
860       {"dx: T", "dy: T"},
861       // Attr defs
862       {{"T: {half, float, double}"}},
863       // Nodes
864       nodes);
865   return OkStatus();
866 }
867 
MatMulGradCommon(const string & opname,const string & attr_adj_x,const string & attr_adj_y,const AttrSlice & attrs,FunctionDef * g,bool enable_broadcasting)868 Status MatMulGradCommon(const string& opname, const string& attr_adj_x,
869                         const string& attr_adj_y, const AttrSlice& attrs,
870                         FunctionDef* g, bool enable_broadcasting) {
871   DataType T;
872   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
873   if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
874     return errors::Unimplemented(
875         "MatMul gradient for complex is not supported yet.");
876   }
877   bool ta;
878   bool tb;
879   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_x, &ta));
880   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_y, &tb));
881   if (!ta && !tb) {
882     return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
883                             true, "x", true, "dz", false, enable_broadcasting);
884   }
885   if (!ta && tb) {
886     return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
887                             false, "dz", true, "x", false, enable_broadcasting);
888   }
889   if (ta && !tb) {
890     return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", false, "dz",
891                             true, "x", false, "dz", false, enable_broadcasting);
892   }
893   CHECK(ta && tb);
894   return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", true, "dz",
895                           true, "dz", true, "x", true, enable_broadcasting);
896 }
897 
MatMulGrad(const AttrSlice & attrs,FunctionDef * g)898 Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
899   return MatMulGradCommon("MatMul", "transpose_a", "transpose_b", attrs, g,
900                           false /* enable_broadcasting */);
901 }
902 REGISTER_OP_GRADIENT("MatMul", MatMulGrad);
903 
BatchMatMulGrad(const AttrSlice & attrs,FunctionDef * g)904 Status BatchMatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
905   return MatMulGradCommon("BatchMatMul", "adj_x", "adj_y", attrs, g,
906                           false /* enable_broadcasting */);
907 }
908 REGISTER_OP_GRADIENT("BatchMatMul", BatchMatMulGrad);
909 
BatchMatMulV2Grad(const AttrSlice & attrs,FunctionDef * g)910 Status BatchMatMulV2Grad(const AttrSlice& attrs, FunctionDef* g) {
911   return MatMulGradCommon("BatchMatMulV2", "adj_x", "adj_y", attrs, g,
912                           true /* enable_broadcasting */);
913 }
914 REGISTER_OP_GRADIENT("BatchMatMulV2", BatchMatMulV2Grad);
915 
916 // REGISTER_OP_GRADIENT("SparseMatMul", SparseMatMulGrad);
917 
// Every op below is registered through REGISTER_OP_NO_GRADIENT, i.e. it is
// explicitly declared to have no gradient function.

// Comparison ops.
REGISTER_OP_NO_GRADIENT("Less");
REGISTER_OP_NO_GRADIENT("LessEqual");
REGISTER_OP_NO_GRADIENT("Greater");
REGISTER_OP_NO_GRADIENT("GreaterEqual");
REGISTER_OP_NO_GRADIENT("Equal");
REGISTER_OP_NO_GRADIENT("NotEqual");

// Logical ops.
REGISTER_OP_NO_GRADIENT("LogicalAnd");
REGISTER_OP_NO_GRADIENT("LogicalOr");
REGISTER_OP_NO_GRADIENT("LogicalNot");

// Sequence generation ops.
REGISTER_OP_NO_GRADIENT("Range");
REGISTER_OP_NO_GRADIENT("LinSpace");

// NOTE(review): rounding / integer-division ops, presumably grouped here
// because their outputs are piecewise constant -- confirm.
REGISTER_OP_NO_GRADIENT("Floor");
REGISTER_OP_NO_GRADIENT("FloorDiv");
REGISTER_OP_NO_GRADIENT("TruncateDiv");
938 
939 }  // end namespace tensorflow
940