/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

typedef FunctionDefHelper FDH;
// Cwise unary ops
Status GradForUnaryCwise(FunctionDef* g, std::vector<FDH::Node> nodes) {
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", "$T"}};
    }
  }
  *g = FDH::Define(
      // Arg defs
      {"x: T", "dy: T"},
      // Ret val defs
      {"dx: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      nodes);
  return Status::OK();
}
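
// For orientation (our summary, not something the helper itself emits): the
// FunctionDef produced above has signature (x: T, dy: T) -> dx: T, so each
// unary gradient below only has to supply nodes computing "dx" from the input
// "x" and the incoming gradient "dy". A hypothetical caller would look like:
//
//   FunctionDef g;
//   TF_RETURN_IF_ERROR(GradForUnaryCwise(&g, {
//       {{"dx"}, "Identity", {"dy"}},  // gradient of an identity-like op
//   }));
//
// Any node without explicit attrs gets {"T", "$T"} filled in, keeping the body
// datatype-generic over {half, float, double}.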

Status AbsGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"sign"}, "Sign", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "sign"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Abs", AbsGrad);

Status NegGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"dx"}, "Neg", {"dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Neg", NegGrad);

Status InvGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Reciprocal", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      {{"y2_neg"}, "Neg", {"y2"}},
      {{"dx"}, "Mul", {"dy", "y2_neg"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Inv", InvGrad);
REGISTER_OP_GRADIENT("Reciprocal", InvGrad);

Status SquareGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("c", 2LL),
      {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
      {{"x2"}, "Mul", {"x", "two"}, {}, {"dy"}},  // x * 2
      {{"dx"}, "Mul", {"dy", "x2"}},              // dy * (x * 2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Square", SquareGrad);

Status SqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Sqrt", {"x"}},
      {{"y_inv"}, "Reciprocal", {"y"}, {}, {"dy"}},
      FDH::Const("const", 0.5f),
      {{"half"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Mul", {"half", "y_inv"}},  // .5 * 1/y
      {{"dx"}, "Mul", {"dy", "a"}},       // dy * (.5 * 1/y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sqrt", SqrtGrad);

Status RsqrtGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}},
      {{"y"}, "Rsqrt", {"x"}},
      FDH::Const("const", -.5f),
      {{"neghalf"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Mul", {"neghalf", "x_inv"}},  // -0.5 * 1/x
      {{"b"}, "Mul", {"a", "y"}},            // -0.5 * 1/x * y
110 {{"dx"}, "Mul", {"dy", "b"}}, // dy * (1/y * .5)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Rsqrt", RsqrtGrad);
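
// For intuition (our annotation, not part of the registered function): with
// y = rsqrt(x) = x^(-1/2), the derivative is d/dx x^(-1/2) = -1/2 * x^(-3/2),
// which factors as (-0.5) * (1/x) * y. That is exactly the product the Mul
// nodes above build before scaling by the incoming gradient dy.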

Status ExpGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Exp", {"x"}},
      {{"dx"}, "Mul", {"dy", "y"}},  // dy * y
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Exp", ExpGrad);

Status Expm1Grad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Exp", {"x"}},
      {{"dx"}, "Mul", {"dy", "y"}},  // dy * y
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Expm1", Expm1Grad);

Status LogGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x_inv"}, "Reciprocal", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "x_inv"}},  // dy * 1/x
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Log", LogGrad);

Status Log1pGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Add", {"one", "x"}},
      {{"dx"}, "Div", {"dy", "a"}},  // dy / (1 + x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Log1p", Log1pGrad);

Status SinhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"cosh"}, "Cosh", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "cosh"}},  // dy * cosh(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sinh", SinhGrad);

Status CoshGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"sinh"}, "Sinh", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "sinh"}},  // dy * sinh(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Cosh", CoshGrad);

Status TanhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Tanh", {"x"}},
      {{"y2"}, "Square", {"y"}, {}, {"dy"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "y2"}},
      {{"dx"}, "Mul", {"dy", "a"}},  // dy * (1 - y*y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Tanh", TanhGrad);

Status AsinhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Asinh", {"x"}},
      {{"cosh"}, "Cosh", {"y"}},
      {{"dx"}, "Mul", {"dy", "cosh"}},  // dy * cosh(y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Asinh", AsinhGrad);

Status AcoshGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Acosh", {"x"}},
      {{"sinh"}, "Sinh", {"y"}},
      {{"dx"}, "Mul", {"dy", "sinh"}},  // dy * sinh(y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Acosh", AcoshGrad);

Status AtanhGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "x2"}},  // 1 - x^2
      {{"inv"}, "Reciprocal", {"a"}},
      {{"dx"}, "Mul", {"dy", "inv"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Atanh", AtanhGrad);

Status SigmoidGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"y"}, "Sigmoid", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "y"}, {}, {"dy"}},
      {{"b"}, "Mul", {"y", "a"}},    // y * (1 - y)
      {{"dx"}, "Mul", {"dy", "b"}},  // dy * y * (1 - y)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sigmoid", SigmoidGrad);

Status SignGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"s"}, "Shape", {"x"}},
      FDH::Const("zero", 0.f),
      {{"val"}, "Cast", {"zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"dx"}, "Fill", {"s", "val"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sign", SignGrad);

Status SinGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"cos"}, "Cos", {"x"}, {}, {"dy"}},
      {{"dx"}, "Mul", {"dy", "cos"}},  // dy * cos(x)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sin", SinGrad);

Status CosGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"sin"}, "Sin", {"x"}, {}, {"dy"}},
      {{"neg"}, "Neg", {"sin"}},
      {{"dx"}, "Mul", {"dy", "neg"}},  // dy * (-sin(x))
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Cos", CosGrad);

Status AcosGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "x2"}},  // 1 - x^2
      {{"b"}, "Sqrt", {"a"}},
      {{"inv"}, "Reciprocal", {"b"}},
      {{"neg"}, "Neg", {"inv"}},
      {{"dx"}, "Mul", {"dy", "neg"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Acos", AcosGrad);

Status AsinGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Sub", {"one", "x2"}},  // 1 - x^2
      {{"b"}, "Sqrt", {"a"}},
      {{"inv"}, "Reciprocal", {"b"}},
      {{"dx"}, "Mul", {"dy", "inv"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Asin", AsinGrad);

Status AtanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"x2"}, "Square", {"x"}},
      FDH::Const("const", 1.0f),
      {{"one"}, "Cast", {"const"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
      {{"a"}, "Add", {"one", "x2"}},  // 1 + x^2
      {{"inv"}, "Reciprocal", {"a"}},
      {{"dx"}, "Mul", {"dy", "inv"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Atan", AtanGrad);

Status TanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"cosx"}, "Cos", {"x"}},
      {{"secx"}, "Reciprocal", {"cosx"}},
      {{"secx2"}, "Square", {"secx"}},
      {{"dx"}, "Mul", {"dy", "secx2"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Tan", TanGrad);

Status RealGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("zero", 0.f),
      {{"dx"}, "Complex", {"dy", "zero"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Real", RealGrad);

Status ImagGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      FDH::Const("zero", 0.f),
      {{"dx"}, "Complex", {"zero", "dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Imag", ImagGrad);

Status AngleGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"re"}, "Real", {"x"}},
      {{"im"}, "Imag", {"x"}},
      {{"z"}, "Complex", {"im", "re"}},
      {{"z_inv"}, "Reciprocal", {"z"}},
      {{"neg"}, "Neg", {"z_inv"}},
      {{"dx"}, "Mul", {"neg", "dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Angle", AngleGrad);

Status ConjGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForUnaryCwise(g, {
      {{"dx"}, "Conj", {"dy"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Conj", ConjGrad);

Status CastGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x: SrcT", "dy: DstT"},
      // Ret val defs
      {"dx: SrcT"},
      // Attr defs
      {{"SrcT: type"}, {"DstT: type"}},
      // Nodes
      {{{"dx"}, "Cast", {"dy"}, {{"SrcT", "$DstT"}, {"DstT", "$SrcT"}}}});
  return Status::OK();
  // clang-format on
}
REGISTER_OP_GRADIENT("Cast", CastGrad);

// Cwise binary ops
//
// TODO(zhifengc): This can be arranged as a function in the standard
// library.
Status GradForBinaryCwise(FunctionDef* g, std::vector<FDH::Node> body) {
  // clang-format off
  std::vector<FDH::Node> nodes = {
    {{"sx"}, "Shape", {"x"}},
    {{"sy"}, "Shape", {"y"}},
  };
  nodes.insert(nodes.end(), body.begin(), body.end());
  std::vector<FDH::Node> reshapes = {
    {{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "sy"}},
    {{"sum_gx"}, "Sum", {"gx", "rx"}},
    {{"dx"}, "Reshape", {"sum_gx", "sx"}},
    {{"sum_gy"}, "Sum", {"gy", "ry"}},
    {{"dy"}, "Reshape", {"sum_gy", "sy"}},
  };
  nodes.insert(nodes.end(), reshapes.begin(), reshapes.end());

  // clang-format on
  for (auto& n : nodes) {
    // "BroadcastGradientArgs" doesn't need any attrs.
    if (n.attr.empty() && n.op != "BroadcastGradientArgs") {
      n.attr = {{"T", "$T"}};
    }
  }
  *g = FDH::Define(
      // Arg defs
      {"x: T", "y: T", "dz: T"},
      // Ret val defs
      {"dx: T", "dy: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      nodes);
  return Status::OK();
}
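
// For intuition (our annotation, with a made-up shape example): each binary
// gradient body below computes per-element gradients "gx" and "gy" at the
// broadcast shape of z = f(x, y), and the wrapper undoes the broadcasting.
// For x of shape [2, 3] and y of shape [3], BroadcastGradientArgs yields
// rx = [] and ry = [0], so gx is kept as-is while gy is summed over axis 0
// and reshaped back to [3]. Bodies therefore never need to handle input
// shapes themselves.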

Status AddGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Identity", {"dz"}},
      {{"gy"}, "Identity", {"dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Add", AddGrad);

Status SubGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Identity", {"dz"}},
      {{"gy"}, "Neg", {"dz"}},  // -dz
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sub", SubGrad);

Status MulGrad(const AttrSlice& attrs, FunctionDef* g) {
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    return GradForBinaryCwise(
        g, {
               {{"cy"}, "Conj", {"y"}, {}, {"dz"}},
               {{"gx"}, "Mul", {"dz", "cy"}},  // dz * Conj(y)
               {{"cx"}, "Conj", {"x"}, {}, {"dz"}},
               {{"gy"}, "Mul", {"cx", "dz"}},  // Conj(x) * dz
           });
  } else {
    // clang-format off
    return GradForBinaryCwise(g, {
        {{"gx"}, "Mul", {"dz", "y"}},  // dz * y
        {{"gy"}, "Mul", {"x", "dz"}},  // x * dz
    });
    // clang-format on
  }
}
REGISTER_OP_GRADIENT("Mul", MulGrad);

Status MulNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "MulNoNan", {"y", "dz"}},  // y * dz
      {{"gy"}, "MulNoNan", {"x", "dz"}},  // x * dz
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("MulNoNan", MulNoNanGrad);

Status DivGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Div", {"dz", "y"}},
      {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
      {{"y2"}, "Square", {"y"}, {}, {"dz"}},
      {{"nx_y2"}, "Div", {"nx", "y2"}},
      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Div", DivGrad);

Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "RealDiv", {"dz", "y"}},
      {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
      {{"y2"}, "Square", {"y"}, {}, {"dz"}},
      {{"nx_y2"}, "RealDiv", {"nx", "y2"}},
      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("RealDiv", RealDivGrad);

Status DivNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "DivNoNan", {"dz", "y"}},
      {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
      {{"y2"}, "Square", {"y"}, {}, {"dz"}},
      {{"nx_y2"}, "DivNoNan", {"nx", "y2"}},
      {{"gy"}, "Mul", {"dz", "nx_y2"}},  // dz * (- x / y^2)
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("DivNoNan", DivNoNanGrad);

Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  std::vector<FDH::Node> nodes = {
    {{"z"}, "Pow", {"x", "y"}},
    // dz * y * Pow(x, y - 1)
    FDH::Const("const_zero", 0.0f),
    FDH::Const("const_one", 1.0f),
    {{"zero"}, "Cast", {"const_zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
    {{"one"}, "Cast", {"const_one"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
    {{"t0"}, "Sub", {"y", "one"}, {}, {"dz"}},
    {{"t1"}, "Pow", {"x", "t0"}},
    {{"t2"}, "Mul", {"dz", "y"}},
    {{"gx"}, "Mul", {"t1", "t2"}},
    {{"unsafe_log"}, "Log", {"x"}, {}, {"dz"}},
    {{"zeros"}, "ZerosLike", {"x"}}};
  // clang-format on
  std::vector<FDH::Node> log_x_handling;
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    // dz * z * (x != 0 ? Log(x) : 0)
    // clang-format off
    log_x_handling = {
      {{"nz_x"}, "NotEqual", {"x", "zero"}},
      {{"safe_log"}, "Select", {"nz_x", "unsafe_log", "zeros"}}};
    // clang-format on
  } else {
    // dz * z * (x > 0 ? Log(x) : 0)
    // clang-format off
    log_x_handling = {
      {{"pos_x"}, "Greater", {"x", "zero"}},
      {{"safe_log"}, "Select", {"pos_x", "unsafe_log", "zeros"}}};
    // clang-format on
  }
  nodes.insert(nodes.end(), log_x_handling.begin(), log_x_handling.end());
  nodes.push_back({{"t4"}, "Mul", {"dz", "z"}});
  nodes.push_back({{"gy"}, "Mul", {"safe_log", "t4"}});
  return GradForBinaryCwise(g, nodes);
}
REGISTER_OP_GRADIENT("Pow", PowGrad);
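
// For intuition (our annotation): with z = x^y, the partial derivatives are
//   dz/dx = y * x^(y - 1)             -> gx = dz * y * Pow(x, y - 1)
//   dz/dy = x^y * log(x) = z * log(x) -> gy = dz * z * log(x)
// Log(x) is undefined for x <= 0 (real types) or x == 0 (complex types), so
// the Select above substitutes 0 for log(x) outside its domain, i.e. the
// gradient w.r.t. the exponent is taken to be 0 there.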

Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"zeros"}, "ZerosLike", {"x"}},
      {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
      {{"is_zero_cast"}, "Cast", {"is_x_zero"},
        {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"safe_logy"}, "Xlogy", {"is_zero_cast", "y"}},
      {{"xlogygrad"}, "Xdivy", {"x", "y"}},
      {{"gx"}, "Mul", {"safe_logy", "dz"}},
      {{"gy"}, "Mul", {"xlogygrad", "dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Xlogy", XlogyGrad);

Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"zeros"}, "ZerosLike", {"x"}},
      {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
      {{"is_zero_cast"}, "Cast", {"is_x_zero"},
        {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"safe_divy"}, "Xdivy", {"is_zero_cast", "y"}},
      {{"y2"}, "Square", {"y"}},
      {{"negy2"}, "Neg", {"y2"}},
      {{"xdivygrad"}, "Xdivy", {"x", "negy2"}},
      {{"gx"}, "Mul", {"safe_divy", "dz"}},
      {{"gy"}, "Mul", {"xdivygrad", "dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Xdivy", XdivyGrad);

Status SquaredDifferenceGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      FDH::Const("c", 2LL),
      {{"two"}, "Cast", {"c"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
      {{"x_sub_y"}, "Sub", {"x", "y"}},
      {{"two_x_sub_y"}, "Mul", {"two", "x_sub_y"}},  // 2 * (x - y)
      {{"gx"}, "Mul", {"two_x_sub_y", "dz"}},
      {{"gy"}, "Neg", {"gx"}}
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("SquaredDifference", SquaredDifferenceGrad);

Status MaximumMinimumGradHelper(const string& comparator,
                                const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"c"}, comparator, {"x", "y"}, {}, {"dz"}},
      {{"mask"}, "Cast", {"c"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
      {{"gx"}, "Mul", {"dz", "mask"}},
      {{"gy"}, "Sub", {"dz", "gx"}},
  });
  // clang-format on
}
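
// For intuition (our annotation): "mask" is 1 where the comparator selects x
// (e.g. x >= y for Maximum) and 0 elsewhere, so gx = dz * mask routes the
// incoming gradient to the winning input and gy = dz - gx sends the remainder
// to the other one. On ties the whole gradient goes to x.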

Status MaximumGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MaximumMinimumGradHelper("GreaterEqual", attrs, g);
}
REGISTER_OP_GRADIENT("Maximum", MaximumGrad);

Status MinimumGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MaximumMinimumGradHelper("LessEqual", attrs, g);
}
REGISTER_OP_GRADIENT("Minimum", MinimumGrad);

Status ComplexGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForBinaryCwise(g, {
      {{"gx"}, "Real", {"dz"}},
      {{"gy"}, "Imag", {"dz"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Complex", ComplexGrad);

// Cwise ternary ops.
Status SelectGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      {"c:bool", "x:T", "y:T", "dz:T"},
      {"dc:bool", "dx:T", "dy:T"},
      {{"T: {half, float, double}"}},
      {
        {{"dc"}, "ZerosLike", {"c"}, {{"T", DT_BOOL}}, {"dz"}},
        {{"zeros"}, "ZerosLike", {"x"}, {{"T", "$T"}}, {"dz"}},
        {{"dx"}, "Select", {"c", "dz", "zeros"}, {{"T", "$T"}}},
        {{"dy"}, "Select", {"c", "zeros", "dz"}, {{"T", "$T"}}},
      });
  // clang-format on
  return Status::OK();
}
REGISTER_OP_GRADIENT("Select", SelectGrad);

// N-ary ops
// REGISTER_OP_GRADIENT("AddN", AddNGrad);

// Reduction ops
//
// TODO(zhifengc): This helper is pretty ugly. Do something better.
// TODO(zhifengc): This can be arranged as a function in the standard library.
Status GradForReductionOp(FunctionDef* g, std::vector<FDH::Node> body) {
  // Shape manipulation nodes.

  // clang-format off
  std::vector<FDH::Node> nodes = {
    {{"x_shape"}, "Shape", {"x"}},
    {{"x_rank"}, "Rank", {"x"}},
    {{"i_shape"}, "Shape", {"i"}, {{"T", DT_INT32}}},
    FDH::Const("zero", 0),
    FDH::Const("one", 1),
    // stitch_idx0 = Range(0, x_rank, 1)
    {{"stitch_val1"}, "Fill", {"i_shape:output:0", "one:output:0"},
      {{"T", DT_INT32}}},
    {{"y_shape"}, "DynamicStitch",
      {"stitch_idx0:output:0", "i",
       "x_shape:output:0", "stitch_val1:output:0"},
      {{"N", 2}, {"T", DT_INT32}}},
    {{"tile_scaling"}, "Div", {"x_shape:output:0", "y_shape:merged:0"},
      {{"T", DT_INT32}}},
    {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
  };
  // clang-format on
  nodes.insert(nodes.end(), body.begin(), body.end());
  for (auto& n : nodes) {
    if (n.attr.empty()) {
      n.attr = {{"T", "$T"}};
    }
  }
  // "Range" doesn't need any attr.
  nodes.push_back({{"stitch_idx0"},
                   "Range",
                   {"zero:output:0", "x_rank:output:0", "one:output:0"},
                   {}});
  *g = FDH::Create("_",
                   // Input defs
                   {"x:T", "i:int32", "dy:T"},
                   // Ret val defs
                   {"dx:T", "di:int32"},
                   // Attr defs
                   {{"T: {half, float, double}"}},
                   // Nodes
                   nodes,
                   // Return values
                   {{"dx", "dx:output:0"}, {"di", "di:y:0"}});
  return Status::OK();
}
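
// For intuition (our annotation, with a made-up example): for x of shape
// [2, 3] reduced over i = [1], the helper computes y_shape = [2, 1] (the
// reduced dimensions stitched back in as 1) and tile_scaling = [1, 3]
// (x_shape / y_shape). The Sum/Mean bodies below reshape dy to y_shape and
// tile it by tile_scaling so every input element receives a gradient.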

Status SumGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForReductionOp(g, {
      {{"dy_reshaped"}, "Reshape", {"dy", "y_shape:merged:0"}},
      {{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Sum", SumGrad);

Status MeanGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  return GradForReductionOp(g, {
      {{"factor"}, "Prod", {"tile_scaling:z:0", "zero:output:0"},
        {{"T", DT_INT32}}},
      {{"factor_T"}, "Cast", {"factor:output:0"},
        {{"SrcT", DT_INT32}, {"DstT", "$T"}}},
      {{"dy_scaled"}, "Div", {"dy", "factor_T:y:0"}},
      {{"dy_reshaped"}, "Reshape", {"dy_scaled:z:0", "y_shape:merged:0"}},
      {{"dx"}, "Tile", {"dy_reshaped:output:0", "tile_scaling:z:0"}},
  });
  // clang-format on
}
REGISTER_OP_GRADIENT("Mean", MeanGrad);

// REGISTER_OP_GRADIENT("Prod", ProdGrad);
// REGISTER_OP_GRADIENT("SegmentSum", SegmentSumGrad);
// REGISTER_OP_GRADIENT("SegmentMean", SegmentMeanGrad);
// REGISTER_OP_GRADIENT("SparseSegmentSum", SparseSegmentSumGrad);
// REGISTER_OP_GRADIENT("SparseSegmentMean", SparseSegmentMeanGrad);
// REGISTER_OP_GRADIENT("SparseSegmentSqrtN", SparseSegmentSqrtNGrad);
// REGISTER_OP_GRADIENT("SegmentMin", SegmentMinGrad);
// REGISTER_OP_GRADIENT("SegmentMax", SegmentMaxGrad);
// REGISTER_OP_GRADIENT("UnsortedSegmentSum", UnsortedSegmentSumGrad);
// REGISTER_OP_GRADIENT("UnsortedSegmentMax", UnsortedSegmentMaxGrad);

Status MinMaxGradHelper(const string& op, const AttrSlice& attrs,
                        FunctionDef* g) {
  // clang-format off
  *g = FDH::Define(
      // Arg defs
      {"x:T", "i:int32", "dy:T"},
      // Ret val defs
      {"dx:T", "di:int32"},
      // Attr defs
      {{"T: {half, float, double}"}},
      {
        // keep_dims because we need to do x == y, which requires x
        // and y are broadcastable.
        {{"y"}, op, {"x", "i"}, {{"T", "$T"}, {"keep_dims", true}}},
        {{"mask"}, "Equal", {"x", "y"}, {{"T", "$T"}}},
        {{"mask_cast"}, "Cast", {"mask"}, {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
        {{"mask_sum"}, "Sum", {"mask_cast", "i"}, {{"T", "$T"}}},
        {{"norm_dy"}, "Div", {"dy", "mask_sum"}, {{"T", "$T"}}},
        {{"sy"}, "Shape", {"y"}, {{"T", "$T"}}},
        {{"norm_dy_reshaped"}, "Reshape", {"norm_dy", "sy"}, {{"T", "$T"}}},
        {{"dx"}, "Mul", {"mask_cast", "norm_dy_reshaped"}, {{"T", "$T"}}},
        {{"di"}, "ZerosLike", {"i"}, {{"T", DT_INT32}}}
      });
  // clang-format on
  return Status::OK();
}
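
// For intuition (our annotation, with a made-up example): for x = [1, 3, 3]
// reduced over all elements with op = "Max", y = 3 and mask = [0, 1, 1], so
// mask_sum = 2 and dx = [0, dy/2, dy/2]. The incoming gradient is split
// evenly among all elements that attain the extremum, and the integer
// reduction indices "i" get a zero gradient.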

Status MaxGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MinMaxGradHelper("Max", attrs, g);
}
REGISTER_OP_GRADIENT("Max", MaxGrad);

Status MinGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MinMaxGradHelper("Min", attrs, g);
}
REGISTER_OP_GRADIENT("Min", MinGrad);

static Status MatMulGradHelper(FunctionDef* g, const string& opname,
                               const string& attr_adj_x,
                               const string& attr_adj_y, const string& x0,
                               bool ax0, const string& x1, bool ax1,
                               const string& y0, bool ay0, const string& y1,
                               bool ay1) {
  *g = FDH::Define(
      // Arg defs
      {"x: T", "y: T", "dz: T"},
      // Ret val defs
      {"dx: T", "dy: T"},
      // Attr defs
      {{"T: {half, float, double}"}},
      // Nodes
      {
          {{"dx"},
           opname,
           {x0, x1},
           {{"T", "$T"}, {attr_adj_x, ax0}, {attr_adj_y, ax1}}},
          {{"dy"},
           opname,
           {y0, y1},
           {{"T", "$T"}, {attr_adj_x, ay0}, {attr_adj_y, ay1}}},
      });
  return Status::OK();
}

Status MatMulGradCommon(const string& opname, const string& attr_adj_x,
                        const string& attr_adj_y, const AttrSlice& attrs,
                        FunctionDef* g) {
  DataType T;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
    return errors::Unimplemented(
        "MatMul gradient for complex is not supported yet.");
  }
  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, attr_adj_y, &tb));
  if (!ta && !tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
                            true, "x", true, "dz", false);
  }
  if (!ta && tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "dz", false, "y",
                            false, "dz", true, "x", false);
  }
  if (ta && !tb) {
    return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", false, "dz",
                            true, "x", false, "dz", false);
  }
  CHECK(ta && tb);
  return MatMulGradHelper(g, opname, attr_adj_x, attr_adj_y, "y", true, "dz",
                          true, "dz", true, "x", true);
}
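
// For reference (our annotation): with z = matmul(x, y) and incoming gradient
// dz, the four transpose cases above encode
//   no transpose:    dx = dz . y^T,   dy = x^T . dz
//   y transposed:    dx = dz . y,     dy = dz^T . x
//   x transposed:    dx = y . dz^T,   dy = x . dz
//   both transposed: dx = y^T . dz^T, dy = dz^T . x^T
// each expressed as a single MatMul/BatchMatMul with the appropriate
// transpose/adjoint attrs rather than an explicit Transpose node.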

Status MatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MatMulGradCommon("MatMul", "transpose_a", "transpose_b", attrs, g);
}
REGISTER_OP_GRADIENT("MatMul", MatMulGrad);

Status BatchMatMulGrad(const AttrSlice& attrs, FunctionDef* g) {
  return MatMulGradCommon("BatchMatMul", "adj_x", "adj_y", attrs, g);
}
REGISTER_OP_GRADIENT("BatchMatMul", BatchMatMulGrad);

// REGISTER_OP_GRADIENT("SparseMatMul", SparseMatMulGrad);

// Comparison ops.
REGISTER_OP_NO_GRADIENT("Less");
REGISTER_OP_NO_GRADIENT("LessEqual");
REGISTER_OP_NO_GRADIENT("Greater");
REGISTER_OP_NO_GRADIENT("GreaterEqual");
REGISTER_OP_NO_GRADIENT("Equal");
REGISTER_OP_NO_GRADIENT("NotEqual");

// Logical ops.
REGISTER_OP_NO_GRADIENT("LogicalAnd");
REGISTER_OP_NO_GRADIENT("LogicalOr");
REGISTER_OP_NO_GRADIENT("LogicalNot");

// Sequence generation ops.
REGISTER_OP_NO_GRADIENT("Range");
REGISTER_OP_NO_GRADIENT("LinSpace");

REGISTER_OP_NO_GRADIENT("Floor");
REGISTER_OP_NO_GRADIENT("FloorDiv");
REGISTER_OP_NO_GRADIENT("TruncateDiv");

}  // end namespace tensorflow