/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

namespace tensorflow {

typedef FunctionDefHelper FDH;

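// The gradient functions in this file are expressed as FunctionDefs built
// with FunctionDefHelper (FDH). Each FDH::Node below is written roughly as
//   {{output_names}, "OpName", {inputs}, {attrs}, {control_deps}}
// i.e. the node's outputs, the op it runs, its data inputs, its attrs, and
// optional control dependencies; the helpers below fill in a default
// {"T", "$T"} attr for nodes that leave attrs empty.
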
// Cwise unary ops
Status GradForUnaryCwise(FunctionDef* g, std::vector<FDH::Node> nodes) {
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", "$T"}};
    }
  }
  *g = FDH::Define(
      // Arg defs
      {"x: T", "dy: T"},
      // Ret val defs
      {"dx: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      nodes);
  return OkStatus();
}
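
// Each unary gradient defined via GradForUnaryCwise is a function
// (x: T, dy: T) -> (dx: T) that computes dx = dy * f'(x) for the
// corresponding elementwise op y = f(x).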

Status AbsGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"sign"}, "Sign", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "sign"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Abs", AbsGrad);

Status NegGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"dx"}, "Neg", {"dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Neg", NegGrad);

Status InvGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Reciprocal", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      {{"y2_neg"}, "Neg", {"y2"}},
      {{"dx"}, "Mul", {"dy", "y2_neg"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Inv", InvGrad);
REGISTER_OP_GRADIENT("Reciprocal", InvGrad);

Status SquareGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("c", int64_t{2}),
      {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
      {{"x2"}, "Mul", {"x", "two"}, {}, {"dy"}},  // x * 2
      {{"dx"}, "Mul", {"dy", "x2"}},              // dy * (x * 2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Square", SquareGrad);

Status SqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Sqrt", {"x"}},
      {{"y_inv"}, "Reciprocal", {"y"}, {}, {"dy"}},
      FDH::Const("const", 0.5f),
      {{"half"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Mul", {"half", "y_inv"}},  // .5 * 1/y
      {{"dx"}, "Mul", {"dy", "a"}},       // dy * (.5 * 1/y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sqrt", SqrtGrad);

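// For y = rsqrt(x) = x^(-1/2), dy/dx = -0.5 * x^(-3/2) = -0.5 * (1/x) * y,
// which is what the chain of nodes below builds before multiplying by dy.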
Status RsqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}},
      {{"y"}, "Rsqrt", {"x"}},
      FDH::Const("const", -.5f),
      {{"neghalf"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Mul", {"neghalf", "x_inv"}},  // -0.5 * 1/x
      {{"b"}, "Mul", {"a", "y"}},            // -0.5 * 1/x * y
113 {{"dx"}, "Mul", {"dy", "b"}}, // dy * (1/y * .5)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Rsqrt", RsqrtGrad);

Status ExpGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Exp", {"x"}},
      {{"dx"}, "Mul", {"dy", "y"}},  // dy * y
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Exp", ExpGrad);

Status Expm1Grad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Exp", {"x"}},
      {{"dx"}, "Mul", {"dy", "y"}},  // dy * y
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Expm1", Expm1Grad);

Status LogGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "x_inv"}},  // dy * 1/x
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Log", LogGrad);

Status Log1pGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Add", {"one", "x"}},
      {{"dx"}, "Div", {"dy", "a"}},  // dy / (1 + x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Log1p", Log1pGrad);

Status SinhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"cosh"}, "Cosh", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "cosh"}},  // dy * cosh(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sinh", SinhGrad);

Status CoshGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"sinh"}, "Sinh", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "sinh"}},  // dy * sinh(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Cosh", CoshGrad);

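// For y = tanh(x), dy/dx = 1 - tanh(x)^2 = 1 - y^2, so dx = dy * (1 - y^2).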
Status TanhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},  // dy * (1 - y*y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Tanh", TanhGrad);

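// d/dx asinh(x) = 1 / sqrt(x^2 + 1) = 1 / cosh(asinh(x)), and
// d/dx acosh(x) = 1 / sqrt(x^2 - 1) = 1 / sinh(acosh(x)), so in both cases
// dy is divided by the hyperbolic term evaluated at y.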
Status AsinhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Asinh", {"x"}},
      {{"cosh"}, "Cosh", {"y"}},
      {{"dx"}, "Div", {"dy", "cosh"}},  // dy / cosh(y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Asinh", AsinhGrad);

Status AcoshGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Acosh", {"x"}},
      {{"sinh"}, "Sinh", {"y"}},
      {{"dx"}, "Div", {"dy", "sinh"}},  // dy / sinh(y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Acosh", AcoshGrad);

Status AtanhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "x2"}},  // 1 - x^2
      {{"inv"}, "Reciprocal", {"a"}},
      {{"dx"}, "Mul", {"dy", "inv"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Atanh", AtanhGrad);

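// For y = sigmoid(x) = 1 / (1 + exp(-x)), dy/dx = y * (1 - y).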
Status SigmoidGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Sigmoid", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "y"}, {}, {"dy"}},
      {{"b"}, "Mul", {"y", "a"}},    // y * (1 - y)
      {{"dx"}, "Mul", {"dy", "b"}},  // dy * y * (1 - y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sigmoid", SigmoidGrad);

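// Sign is piecewise constant, so its derivative is zero almost everywhere;
// the gradient is simply a zero tensor with the same shape as x.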
Status SignGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"s"}, "Shape", {"x"}},
      FDH::Const("zero", 0.f),
      {{"val"}, "Cast", {"zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"dx"}, "Fill", {"s", "val"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sign", SignGrad);

Status SinGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"cos"}, "Cos", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "cos"}},  // dy * cos(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sin", SinGrad);

Status CosGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"sin"}, "Sin", {"x"}, {}, {"dy"}},
      {{"neg"}, "Neg", {"sin"}},
      {{"dx"}, "Mul", {"dy", "neg"}},  // dy * (-sin(x))
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Cos", CosGrad);

Status AcosGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "x2"}},  // 1 - x^2
      {{"b"}, "Sqrt", {"a"}},
      {{"inv"}, "Reciprocal", {"b"}},
      {{"neg"}, "Neg", {"inv"}},
      {{"dx"}, "Mul", {"dy", "neg"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Acos", AcosGrad);

Status AsinGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "x2"}},  // 1 - x^2
      {{"b"}, "Sqrt", {"a"}},
      {{"inv"}, "Reciprocal", {"b"}},
      {{"dx"}, "Mul", {"dy", "inv"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Asin", AsinGrad);

Status AtanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Add", {"one", "x2"}},  // 1 + x^2
      {{"inv"}, "Reciprocal", {"a"}},
      {{"dx"}, "Mul", {"dy", "inv"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Atan", AtanGrad);

Status TanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"cosx"}, "Cos", {"x"}},
      {{"secx"}, "Reciprocal", {"cosx"}},
      {{"secx2"}, "Square", {"secx"}},
      {{"dx"}, "Mul", {"dy", "secx2"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Tan", TanGrad);

Status RealGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("zero", 0.f),
      {{"dx"}, "Complex", {"dy", "zero"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Real", RealGrad);

Status ImagGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("zero", 0.f),
      {{"dx"}, "Complex", {"zero", "dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Imag", ImagGrad);

Status AngleGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"re"}, "Real", {"x"}},
      {{"im"}, "Imag", {"x"}},
      {{"z"}, "Complex", {"im", "re"}},
      {{"z_inv"}, "Reciprocal", {"z"}},
      {{"neg"}, "Neg", {"z_inv"}},
      {{"dx"}, "Mul", {"neg", "dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Angle", AngleGrad);

Status ConjGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"dx"}, "Conj", {"dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Conj", ConjGrad);

Status CastGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x: SrcT", "dy: DstT"},
      // Ret val defs
      {"dx: SrcT"},
      // Attr defs
      {{"SrcT: type"}, {"DstT: type"}},
      // Nodes
      {{{"dx"}, "Cast", {"dy"}, {{"SrcT", "$DstT"}, {"DstT", "$SrcT"}}}});
  return OkStatus();
  // clang-format on
}
REGISTER_OP_GRADIENT("Cast", CastGrad);

// Cwise binary ops
//
// TODO(zhifengc): This can be arranged as a function in the standard
// library.
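//
// GradForBinaryCwise builds a function (x: T, y: T, dz: T) -> (dx: T, dy: T).
// The caller-supplied body computes elementwise gradients "gx" and "gy";
// BroadcastGradientArgs then determines which dimensions each input was
// broadcast across, and "gx"/"gy" are summed over those dimensions and
// reshaped back to the shapes of x and y.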
Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) {
  // clang-format off
  std::vector<FDH::Node> nodes = {
      {{"sx"}, "Shape", {"x"}},
      {{"sy"}, "Shape", {"y"}},
  };
  nodes.insert(nodes.end(), body.begin(), body.end());
  std::vector<FDH::Node> reshapes = {
      {{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}},
      {{"sum_gx"}, "Sum", {"gx", "rx"}},
      {{"dx"}, "Reshape", {"sum_gx", "sx"}},
      {{"sum_gy"}, "Sum", {"gy", "ry"}},
      {{"dy"}, "Reshape", {"sum_gy", "sy"}},
  };
  nodes.insert(nodes.end(), reshapes.begin(), reshapes.end());

  // clang-format on
  for (auto& n : nodes) {
    // "BroadcastGradientArgs" doesn't need any attrs.
    if (n.attr.empty() && n.op != "BroadcastGradientArgs") {
      n.attr = {{"T", "$T"}};
    }
  }
  *g = FDH::Define(
      // Arg defs
      {"x: T", "y: T", "dz: T"},
      // Ret val defs
      {"dx: T", "dy: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      nodes);
  return OkStatus();
}

Status AddGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Identity", {"dz"}},
      {{"gy"}, "Identity", {"dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Add", AddGrad);
REGISTER_OP_GRADIENT("AddV2", AddGrad);

Status SubGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Identity", {"dz"}},
      {{"gy"}, "Neg", {"dz"}},  // -dz
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sub", SubGrad);

Status MulGrad(const AttrSlice& attrs, FunctionDef* g) {
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    return GradForBinaryCwise(
        g, {
               {{"cy"}, "Conj", {"y"}, {}, {"dz"}},
               {{"gx"}, "Mul", {"dz", "cy"}},  // dz * Conj(y)
               {{"cx"}, "Conj", {"x"}, {}, {"dz"}},
               {{"gy"}, "Mul", {"cx", "dz"}},  // Conj(x) * dz
           });
  } else {
    // clang-format off
    return GradForBinaryCwise(g, {
        {{"gx"}, "Mul", {"dz", "y"}},  // dz * y
        {{"gy"}, "Mul", {"x", "dz"}},  // x * dz
    });
    // clang-format on
  }
}
REGISTER_OP_GRADIENT("Mul", MulGrad);

Status MulNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "MulNoNan", {"y", "dz"}},  // y * dz
      {{"gy"}, "MulNoNan", {"x", "dz"}},  // x * dz
  });
  // clang-format on
}
484 REGISTER_OP_GRADIENT("MulNoNan", MulGrad);

Status DivGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Div", {"dz", "y"}},
      {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
      {{"y2"}, "Square", {"y"}, {}, {"dz"}},
      {{"nx_y2"}, "Div", {"nx", "y2"}},
      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Div", DivGrad);

Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "RealDiv", {"dz", "y"}},
      {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
      {{"y2"}, "Square", {"y"}, {}, {"dz"}},
      {{"nx_y2"}, "RealDiv", {"nx", "y2"}},
      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("RealDiv", RealDivGrad);

Status DivNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "DivNoNan", {"dz", "y"}},
      {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
      {{"y2"}, "Square", {"y"}, {}, {"dz"}},
      {{"nx_y2"}, "DivNoNan", {"nx", "y2"}},
      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("DivNoNan", DivNoNanGrad);

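// For z = x^y, dz/dx = y * x^(y - 1) and dz/dy = z * log(x). Since log(x) is
// undefined for x <= 0 (x == 0 in the complex case), the "unsafe" Log is
// replaced by 0 there via Select before computing gy.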
Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  std::vector<FDH::Node> nodes = {
      {{"z"}, "Pow", {"x", "y"}},
      // dz * y * Pow(x, y - 1)
      FDH::Const("const_zero", 0.0f),
      FDH::Const("const_one", 1.0f),
      {{"zero"}, "Cast", {"const_zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"one"}, "Cast", {"const_one"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"t0"}, "Sub", {"y", "one"}, {}, {"dz"}},
      {{"t1"}, "Pow", {"x", "t0"}},
      {{"t2"}, "Mul", {"dz", "y"}},
      {{"gx"}, "Mul", {"t1", "t2"}},
      {{"unsafe_log"}, "Log", {"x"}, {}, {"dz"}},
      {{"zeros"}, "ZerosLike", {"x"}}};
  // clang-format on
  std::vector<FDH::Node> log_x_handling;
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    // dz * z * (x != 0 ? Log(x) : 0)
    // clang-format off
    log_x_handling = {
        {{"nz_x"}, "NotEqual", {"x", "zero"}},
        {{"safe_log"}, "Select", {"nz_x", "unsafe_log", "zeros"}}};
    // clang-format on
  } else {
    // dz * z * (x > 0 ? Log(x) : 0)
    // clang-format off
    log_x_handling = {
        {{"pos_x"}, "Greater", {"x", "zero"}},
        {{"safe_log"}, "Select", {"pos_x", "unsafe_log", "zeros"}}};
    // clang-format on
  }
  nodes.insert(nodes.end(), log_x_handling.begin(), log_x_handling.end());
  nodes.push_back({{"t4"}, "Mul", {"dz", "z"}});
  nodes.push_back({{"gy"}, "Mul", {"safe_log", "t4"}});
  return GradForBinaryCwise(g, nodes);
}
REGISTER_OP_GRADIENT("Pow", PowGrad);

Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"zeros"}, "ZerosLike", {"x"}},
      {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
      {{"is_zero_cast"}, "Cast", {"is_x_zero"},
       {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"safe_logy"}, "Xlogy", {"is_zero_cast", "y"}},
      {{"xlogygrad"}, "Xdivy", {"x", "y"}},
      {{"gx"}, "Mul", {"safe_logy", "dz"}},
      {{"gy"}, "Mul", {"xlogygrad", "dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Xlogy", XlogyGrad);

Status Xlog1pyGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"zeros"}, "ZerosLike", {"x"}},
      {{"yp1"}, "Add", {"y", "one"}},
      {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
      {{"is_zero_cast"}, "Cast", {"is_x_zero"},
       {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"safe_log1py"}, "Xlog1py", {"is_zero_cast", "y"}},
      {{"xlog1pygrad"}, "Xdivy", {"x", "yp1"}},
      {{"gx"}, "Mul", {"safe_log1py", "dz"}},
      {{"gy"}, "Mul", {"xlog1pygrad", "dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Xlog1py", Xlog1pyGrad);

Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"zeros"}, "ZerosLike", {"x"}},
      {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
      {{"is_zero_cast"}, "Cast", {"is_x_zero"},
       {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"safe_divy"}, "Xdivy", {"is_zero_cast", "y"}},
      {{"y2"}, "Square", {"y"}},
      {{"negy2"}, "Neg", {"y2"}},
      {{"xdivygrad"}, "Xdivy", {"x", "negy2"}},
      {{"gx"}, "Mul", {"safe_divy", "dz"}},
      {{"gy"}, "Mul", {"xdivygrad", "dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Xdivy", XdivyGrad);

Status SquaredDifferenceGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      FDH::Const("c", int64_t{2}),
      {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
      {{"x_sub_y"}, "Sub", {"x", "y"}},
      {{"two_x_sub_y"}, "Mul", {"two", "x_sub_y"}},  // 2 * (x - y)
      {{"gx"}, "Mul", {"two_x_sub_y", "dz"}},
      {{"gy"}, "Neg", {"gx"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("SquaredDifference", SquaredDifferenceGrad);

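// For z = max(x, y) (resp. min), dz flows to whichever input won the
// elementwise comparison: mask = (x comparator y), gx = dz * mask, and
// gy = dz - gx, so each element of dz is routed to exactly one input
// (ties go to x).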
Status MaximumMinimumGradHelper(const string& comparator,
                                const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"c"}, comparator, {"x", "y"}, {}, {"dz"}},
      {{"mask"}, "Cast", {"c"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"gx"}, "Mul", {"dz", "mask"}},
      {{"gy"}, "Sub", {"dz", "gx"}},
  });
  // clang-format on
}

Status MaximumGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MaximumMinimumGradHelper("GreaterEqual", attrs, g);
}
REGISTER_OP_GRADIENT("Maximum", MaximumGrad);

Status MinimumGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MaximumMinimumGradHelper("LessEqual", attrs, g);
}
REGISTER_OP_GRADIENT("Minimum", MinimumGrad);

Status ComplexGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Real", {"dz"}},
      {{"gy"}, "Imag", {"dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Complex", ComplexGrad);

// Cwise ternary ops.
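// Select(c, x, y) forwards dz to x where the condition c is true and to y
// where it is false; the boolean condition itself gets a zero gradient.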
Status SelectGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      {"c:bool", "x:T", "y:T", "dz:T"},
      {"dc:bool", "dx:T", "dy:T"},
      {{"T: {half, float, double}"}},
      {
        {{"dc"}, "ZerosLike", {"c"}, {{"T", DT_BOOL}}, {"dz"}},
        {{"zeros"}, "ZerosLike", {"x"}, {{"T", "$T"}}, {"dz"}},
        {{"dx"}, "Select", {"c", "dz", "zeros"}, {{"T", "$T"}}},
        {{"dy"}, "Select", {"c", "zeros", "dz"}, {{"T", "$T"}}},
      });
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("Select", SelectGrad);

// N-ry ops
// REGISTER_OP_GRADIENT("AddN", AddNGrad);

// Reduction ops
//
// TODO(zhifengc): This helper is pretty ugly. Do something better.
// TODO(zhifengc): This can be arranged as a function in the standard library.
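//
// GradForReductionOp builds a function (x: T, i: int32, dy: T) ->
// (dx: T, di: int32). The shape nodes compute y_shape (x's shape with every
// dimension listed in i replaced by 1) and tile_scaling = x_shape / y_shape.
// The caller-supplied body reshapes dy to y_shape and tiles it back up to
// x's shape. The "stitch_idx0" Range node is appended after the attr-filling
// loop so it does not receive the default {"T", "$T"} attr, which Range does
// not have.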
Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
  // Shape manipulation nodes.

  // clang-format off
  std::vector<FDH::Node> nodes = {
      {{"x_shape"}, "Shape", {"x"}},
      {{"x_rank"}, "Rank", {"x"}},
      {{"i_shape"}, "Shape", {"i"}, {{"T", DT_INT32}}},
      FDH::Const("zero", 0),
      FDH::Const("one", 1),
      // stitch_idx0 = Range(0, x_rank, 1)
      {{"stitch_val1"}, "Fill", {"i_shape:output:0", "one:output:0"},
       {{"T", DT_INT32}}},
      {{"y_shape"}, "DynamicStitch",
       {"stitch_idx0:output:0", "i",
        "x_shape:output:0", "stitch_val1:output:0"},
       {{"N", 2}, {"T", DT_INT32}}},
      {{"tile_scaling"}, "Div", {"x_shape:output:0", "y_shape:merged:0"},
       {{"T", DT_INT32}}},
      {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
  };
  // clang-format on
  nodes.insert(nodes.end(), body.begin(), body.end());
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", "$T"}};
    }
  }
  // "Range" doesn't need any attr.
  nodes.push_back({{"stitch_idx0"},
                   "Range",
                   {"zero:output:0", "x_rank:output:0", "one:output:0"},
                   {}});
  *g = FDH::Create("_",
                   // Input defs
                   {"x:T", "i:int32", "dy:T"},
                   // Ret val defs
                   {"dx:T", "di:int32"},
                   // Attr defs
                   {{"T: {half, float, double}"}},
                   // Nodes
                   nodes,
                   // Return values
                   {{"dx", "dx:output:0"}, {"di", "di:y:0"}});
  return OkStatus();
}

Status SumGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForReductionOp(g, {
      {{"dy_reshaped"}, "Reshape", {"dy", "y_shape:merged:0"}},
      {{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sum", SumGrad);

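// Mean is Sum divided by the number of reduced elements; that count is
// recovered as the product of the tile_scaling entries ("factor"), and dy is
// divided by it before being tiled back to x's shape.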
Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForReductionOp(g, {
      {{"factor"}, "Prod", {"tile_scaling:z:0", "zero:output:0"},
       {{"T", DT_INT32}}},
      {{"factor_T"}, "Cast", {"factor:output:0"},
       {{"SrcT", DT_INT32}, {"DstT", "$T"}}},
      {{"dy_scaled"}, "Div", {"dy", "factor_T:y:0"}},
      {{"dy_reshaped"}, "Reshape", {"dy_scaled:z:0", "y_shape:merged:0"}},
      {{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Mean", MeanGrad);

// REGISTER_OP_GRADIENT("Prod", ProdGrad);
// REGISTER_OP_GRADIENT("SegmentSum", SegmentSumGrad);
// REGISTER_OP_GRADIENT("SegmentMean", SegmentMeanGrad);
// REGISTER_OP_GRADIENT("SparseSegmentSum", SparseSegmentSumGrad);
// REGISTER_OP_GRADIENT("SparseSegmentMean", SparseSegmentMeanGrad);
// REGISTER_OP_GRADIENT("SparseSegmentSqrtN", SparseSegmentSqrtNGrad);
// REGISTER_OP_GRADIENT("SegmentMin", SegmentMinGrad);
// REGISTER_OP_GRADIENT("SegmentMax", SegmentMaxGrad);
// REGISTER_OP_GRADIENT("UnsortedSegmentSum", UnsortedSegmentSumGrad);
// REGISTER_OP_GRADIENT("UnsortedSegmentMax", UnsortedSegmentMaxGrad);

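// For Max/Min reductions only the elements equal to the reduced result get
// gradient. The op is recomputed with keep_dims=true so y broadcasts against
// x; mask = (x == y) selects the winning elements, and dy is divided by the
// number of tied elements (mask_sum) so the gradient is split evenly among
// them.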
Status MinMaxGradHelper(const string& op, const AttrSlice& attrs,
                        FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x:T", "i:int32", "dy:T"},
      // Ret val defs
      {"dx:T", "di:int32"},
      // Attr defs
      {{"T: {half, float, double}"}},
      {
        // keep_dims because we need to do x == y, which requires x and y
        // to be broadcastable.
786 {{"y"}, op, {"x", "i"}, {{"T", "$T"}, {"keep_dims", true}}},
787 {{"mask"}, "Equal", {"x", "y"}, {{"T", "$T"}}},
788 {{"mask_cast"}, "Cast", {"mask"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
789 {{"mask_sum"}, "Sum", {"mask_cast", "i"}, {{"T", "$T"}}},
790 {{"norm_dy"}, "Div", {"dy", "mask_sum"}, {{"T", "$T"}}},
791 {{"sy"}, "Shape", {"y"}, {{"T", "$T"}}},
792 {{"norm_dy_reshaped"}, "Reshape", {"norm_dy", "sy"}, {{"T", "$T"}}},
793 {{"dx"}, "Mul", {"mask_cast", "norm_dy_reshaped"}, {{"T", "$T"}}},
794 {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
795 });
796 // clang-format on
797 return OkStatus();
798 }
799
MaxGrad(const AttrSlice & attrs,FunctionDef * g)800 Status MaxGrad(const AttrSlice& attrs, FunctionDef* g) {
801 return MinMaxGradHelper("Max", attrs, g);
802 }
803 REGISTER_OP_GRADIENT("Max", MaxGrad);
804
MinGrad(const AttrSlice & attrs,FunctionDef * g)805 Status MinGrad(const AttrSlice& attrs, FunctionDef* g) {
806 return MinMaxGradHelper("Min", attrs, g);
807 }
808 REGISTER_OP_GRADIENT("Min", MinGrad);
809
static Status MatMulGradHelper(FunctionDef* g, const string& opname,
                               const string& attr_adj_x,
                               const string& attr_adj_y, const string& x0,
                               bool ax0, const string& x1, bool ax1,
                               const string& y0, bool ay0, const string& y1,
                               bool ay1, bool enable_broadcasting) {
  // The final outputs are "dx" and "dy". If we're broadcasting, compute the
  // intermediate gradients "gx" and "gy" first, then reduce and reshape them
  // to "dx" and "dy" below.
  std::vector<FDH::Node> nodes = {
      {{(enable_broadcasting ? "gx" : "dx")},
       opname,
       {x0, x1},
       {{"T", "$T"}, {attr_adj_x, ax0}, {attr_adj_y, ax1}}},
      {{(enable_broadcasting ? "gy" : "dy")},
       opname,
       {y0, y1},
       {{"T", "$T"}, {attr_adj_x, ay0}, {attr_adj_y, ay1}}},
  };
  // TODO(anudhyan): Figure out a way to inspect the static shapes of "x" and
  // "y". If they have the same batch dimensions, then we can omit adding the
  // broadcasting-specific ops.
  if (enable_broadcasting) {
    std::vector<FDH::Node> unbroadcast_gradients = {
        FDH::Const<int32>("zero", gtl::ArraySlice<int32>{0}),
        FDH::Const<int32>("one", gtl::ArraySlice<int32>{1}),
        FDH::Const<int32>("minustwo", gtl::ArraySlice<int32>{-2}),
        // Compute the batch shapes of the inputs (all but last two dims).
        {{"sx"}, "Shape", {"x"}, {{"T", "$T"}}},
        {{"sy"}, "Shape", {"y"}, {{"T", "$T"}}},
        {{"batch_sx"},
         "StridedSlice",
         {"sx", "zero", "minustwo", "one"},
         {{"T", DT_INT32}, {"Index", DT_INT32}}},
        {{"batch_sy"},
         "StridedSlice",
         {"sy", "zero", "minustwo", "one"},
         {{"T", DT_INT32}, {"Index", DT_INT32}}},
        // Sum along dimensions that the inputs were broadcasted across.
        {{"rx", "ry"}, "BroadcastGradientArgs", {"batch_sx", "batch_sy"}},
        {{"sum_gx"}, "Sum", {"gx", "rx"}, {{"T", "$T"}}},
        {{"sum_gy"}, "Sum", {"gy", "ry"}, {{"T", "$T"}}},
        {{"dx"}, "Reshape", {"sum_gx", "sx"}, {{"T", "$T"}}},
        {{"dy"}, "Reshape", {"sum_gy", "sy"}, {{"T", "$T"}}}};
    nodes.insert(nodes.end(), unbroadcast_gradients.begin(),
                 unbroadcast_gradients.end());
  }
  *g = FDH::Define(
      // Arg defs
      {"x: T", "y: T", "dz: T"},
      // Ret val defs
      {"dx: T", "dy: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      nodes);
  return OkStatus();
}

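// For z = MatMul(x, y), dx and dy are themselves matmuls of dz with the
// other input, with transposes chosen according to the forward op's
// transpose/adjoint flags (ta, tb):
//   ta=0, tb=0:  dx = dz . y^T,   dy = x^T . dz
//   ta=0, tb=1:  dx = dz . y,     dy = dz^T . x
//   ta=1, tb=0:  dx = y . dz^T,   dy = x . dz
//   ta=1, tb=1:  dx = y^T . dz^T, dy = dz^T . x^T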
Status MatMulGradCommon(const string& opname, const string& attr_adj_x,
                        const string& attr_adj_y, const AttrSlice& attrs,
                        FunctionDef* g, bool enable_broadcasting) {
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex is not supported yet.");
  }
  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_y, &tb));
  if (!ta && !tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
                            true, "x", true, "dz", false, enable_broadcasting);
  }
  if (!ta && tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
                            false, "dz", true, "x", false, enable_broadcasting);
  }
  if (ta && !tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", false, "dz",
                            true, "x", false, "dz", false, enable_broadcasting);
  }
  CHECK(ta && tb);
  return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", true, "dz",
                          true, "dz", true, "x", true, enable_broadcasting);
}

Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MatMulGradCommon("MatMul", "transpose_a", "transpose_b", attrs, g,
                          false /* enable_broadcasting */);
}
REGISTER_OP_GRADIENT("MatMul", MatMulGrad);

Status BatchMatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MatMulGradCommon("BatchMatMul", "adj_x", "adj_y", attrs, g,
                          false /* enable_broadcasting */);
}
REGISTER_OP_GRADIENT("BatchMatMul", BatchMatMulGrad);

Status BatchMatMulV2Grad(const AttrSlice& attrs, FunctionDef* g) {
  return MatMulGradCommon("BatchMatMulV2", "adj_x", "adj_y", attrs, g,
                          true /* enable_broadcasting */);
}
REGISTER_OP_GRADIENT("BatchMatMulV2", BatchMatMulV2Grad);

// REGISTER_OP_GRADIENT("SparseMatMul", SparseMatMulGrad);

// Comparison ops.
REGISTER_OP_NO_GRADIENT("Less");
REGISTER_OP_NO_GRADIENT("LessEqual");
REGISTER_OP_NO_GRADIENT("Greater");
REGISTER_OP_NO_GRADIENT("GreaterEqual");
REGISTER_OP_NO_GRADIENT("Equal");
REGISTER_OP_NO_GRADIENT("NotEqual");

// Logical ops.
REGISTER_OP_NO_GRADIENT("LogicalAnd");
REGISTER_OP_NO_GRADIENT("LogicalOr");
REGISTER_OP_NO_GRADIENT("LogicalNot");

// Sequence generation ops.
REGISTER_OP_NO_GRADIENT("Range");
REGISTER_OP_NO_GRADIENT("LinSpace");

REGISTER_OP_NO_GRADIENT("Floor");
REGISTER_OP_NO_GRADIENT("FloorDiv");
REGISTER_OP_NO_GRADIENT("TruncateDiv");

}  // end namespace tensorflow