/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cmath>

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/gradients/grad_helper.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/math_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
namespace ops {
namespace {

// Logical operations have no gradients.
REGISTER_NO_GRADIENT_OP("Less");
REGISTER_NO_GRADIENT_OP("LessEqual");
REGISTER_NO_GRADIENT_OP("Greater");
REGISTER_NO_GRADIENT_OP("GreaterEqual");
REGISTER_NO_GRADIENT_OP("Equal");
REGISTER_NO_GRADIENT_OP("ApproximateEqual");
REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");
REGISTER_NO_GRADIENT_OP("Floor");

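// Gradient functions in this file are registered by op name (see
// REGISTER_GRADIENT_OP below) and looked up by the symbolic gradient builder;
// ops registered with REGISTER_NO_GRADIENT_OP are treated as
// non-differentiable and receive NoGradient(). A rough usage sketch, with x
// and y standing for arbitrary Outputs (see
// tensorflow/cc/framework/gradients.h for the actual API):
//
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(scope, {y}, {x}, &grads));
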
// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}

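// Note on complex types: throughout this file the backpropagated value is
// grad(x) = grad(y) * conj(dy/dx). ConjugateHelper supplies the conj() and is
// a no-op for real dtypes, so e.g. for y = x^2 the result is
// grad(y) * conj(2 * x).
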
// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy;
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::ReciprocalGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::SqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::RsqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);

Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::TanhGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::SigmoidGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
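  // sign(x) is piecewise constant, so its derivative is 0 everywhere it is
  // defined; propagate a zero tensor shaped like the input.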
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = - 1 / sqrt(1 - x * x)
  // dx = dy * (- 1 / sqrt(1 - x * x))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

Status Atan2Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  auto y = op.input(0);
  auto x = op.input(1);
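  // z = atan2(y, x)
  // dz/dy = x / (x^2 + y^2)
  // dz/dx = -y / (x^2 + y^2)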
  Output grad_inv = Div(scope, grad_inputs[0],
                        Add(scope, Square(scope, x), Square(scope, y)));
  grad_outputs->push_back(Mul(scope, x, grad_inv));
  grad_outputs->push_back(Mul(scope, Neg(scope, y), grad_inv));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan2", Atan2Grad);

// BinaryGradCommon handles the setup for binary ops that broadcast
// their inputs.
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
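  // BroadcastGradientArgs returns, for each input, the axes along which that
  // input was broadcast; summing the gradient over those axes and reshaping
  // restores the input's original shape.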
  auto rx = internal::BroadcastGradientArgs(scope, sx_1, sx_2);
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}

Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);
REGISTER_GRADIENT_OP("AddV2", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);

Status DivNoNanGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = DivNoNan(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  DivNoNan(scope, DivNoNan(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("DivNoNan", DivNoNanGrad);

Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  // hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32_t i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

Status PowGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = ConjugateHelper(scope, op.input(0));
  auto y = ConjugateHelper(scope, op.input(1));
  auto z = ConjugateHelper(scope, op.output(0));
  auto grad = grad_inputs[0];
  // grad * y * pow(x, y - 1)
  auto one = Cast(scope, Const(scope, 1.0), y.type());
  auto gx_1 =
      Mul(scope, Mul(scope, grad, y), Pow(scope, x, Sub(scope, y, one)));
  // Avoid false singularity at x = 0
  DataType x_dtype = x.type();
  auto zero = Cast(scope, Const(scope, 0.0), x_dtype);
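  // The gradient w.r.t. the exponent is grad * z * log(x), since
  // d/dy pow(x, y) = pow(x, y) * log(x).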
  if (x_dtype == DT_COMPLEX64 || x_dtype == DT_COMPLEX128) {
    // real(x) < 0 is fine for the complex case
    auto log_x = Where3(scope, NotEqual(scope, x, zero), Log(scope, x),
                        ZerosLike(scope, x));
    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
  } else {
    // There's no sensible real value to return if x < 0, so return 0
    auto log_x = Where3(scope, Greater(scope, x, zero), Log(scope, x),
                        ZerosLike(scope, x));
    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
  }
}
REGISTER_GRADIENT_OP("Pow", PowGrad);

// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, comparator, zeros, grad);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
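  // Where the inputs are equal, GreaterEqual is true, so ties send the whole
  // gradient to the first input.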
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ComplexGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto gx_1 = Real(scope, grad_inputs[0]);
  auto gx_2 = Imag(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Complex", ComplexGrad);

Status AngleGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = Angle(x)
  // dx = -dy / (Im(x) + iRe(x)) = -dy * z
  auto re = Real(scope, op.input(0));
  auto im = Imag(scope, op.input(0));
  auto z_inv = Reciprocal(scope, Complex(scope, im, re));
  auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
  auto grad = Complex(scope, grad_inputs[0], zero);
  auto dx = Neg(scope, Mul(scope, grad, z_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Angle", AngleGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// Integer division x / y, assuming x and y >= 0, but treats x/0 = x
Output SafeDivHelper(const Scope& scope, const Output& x, const Output& y) {
  return Div(scope, x, Maximum(scope, y, Const(scope, 1)));
}

// SumGradHelper returns the gradient for the Sum operator, and is used
// by SumGrad and MeanGrad.
Output SumGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is just 1, so we only need to replicate the output gradient on such a
  // dimension to its "expanded" shape.
  // Running example:
  // input is
  // [[a, b, c],
  //  [d, e, f]]
  // reduction_indices = [1]
  // Sum = [a + b + c, d + e + f]
  // if the gradient is [g1, g2]
  // We want the propagated gradient to be
  // [[g1, g1, g1],
  //  [g2, g2, g2]]

  // input_shape = [2, 3]
  auto input_shape = Shape(scope, op.input(0));

  // output_shape_kept_dims = [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, op.input(1));

  // This step "flips" any 1s with values from the input_shape, and
  // replaces remaining entries with 1. This creates a shape that
  // shows how much each dimension in the incoming gradient should be
  // replicated.
  // tile_scaling = [1, 3]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // grad = [[g1], [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // tile(grad, tile_scaling) = [[g1, g1, g1], [g2, g2, g2]]
  return Tile(scope, grad, tile_scaling);
}

Status SumGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(SumGradHelper(scope, op, grad_inputs));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Sum", SumGrad);

Status MeanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // The Mean gradient is just like the Sum gradient, except that
  // all gradients are also divided by the size of reduced groups.
  auto sum_grad = SumGradHelper(scope, op, grad_inputs);

  // The product of all entries in a tensor's shape is the total
  // number of entries in the tensor. This step calculates
  // n_input_entries/n_output_entries
  // = group_size
  auto input_shape = Shape(scope, op.input(0));
  auto output_shape = Shape(scope, op.output(0));
  auto zero = Const(scope, 0);
  auto group_size = SafeDivHelper(scope, Prod(scope, input_shape, zero),
                                  Prod(scope, output_shape, zero));

  // propagate sum_grad/group_size
  grad_outputs->push_back(
      Div(scope, sum_grad, Cast(scope, group_size, sum_grad.type())));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Mean", MeanGrad);

Status ErfGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto grad = grad_inputs[0];
  auto two_over_root_pi =
      Cast(scope, Const(scope, 2 / std::sqrt(M_PI)), grad.type());
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
  // grad * 2/sqrt(pi) * exp(-x**2)
  auto dx = Mul(grad_scope, Mul(grad_scope, grad, two_over_root_pi),
                Exp(grad_scope, Neg(grad_scope, Square(grad_scope, x))));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Erf", ErfGrad);

Status ErfinvGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  auto grad = grad_inputs[0];
  auto root_pi_over_two =
      Cast(scope, Const(scope, std::sqrt(M_PI) / 2), grad.type());
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
  // grad * sqrt(pi) / 2 * exp(erfinv(x) ** 2)
  auto dx = Mul(grad_scope, Mul(grad_scope, grad, root_pi_over_two),
                Exp(grad_scope, Square(grad_scope, op.output(0))));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Erfinv", ErfinvGrad);

Status NdtriGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  auto grad = grad_inputs[0];
  auto root_two_pi =
      Cast(scope, Const(scope, std::sqrt(2 * M_PI)), grad.type());
  auto two = Cast(scope, Const(scope, 2), grad.type());
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
  // grad * sqrt(2 * pi) * exp(ndtri(x) ** 2 / 2)
  auto dx = Mul(
      grad_scope, Mul(grad_scope, grad, root_two_pi),
      Exp(grad_scope, Div(grad_scope, Square(grad_scope, op.output(0)), two)));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Ndtri", NdtriGrad);

Status LgammaGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  auto grad = grad_inputs[0];
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
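  // d/dx lgamma(x) = digamma(x), the logarithmic derivative of the gamma
  // function.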
  auto dx = Mul(grad_scope, grad, Digamma(grad_scope, x));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Lgamma", LgammaGrad);

Status MinOrMaxGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is 1 when it is the min (or max) and 0 everywhere else. So the
  // gradient calculation is identical for both operators.
  //
  // There's a special case for propagating gradients when there are
  // multiple minima (or maxima) - we choose to divide the gradient
  // equally among all matching inputs.
  //
  // Please note this comment
  // https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063
  // for details.

  // Running example:
  // input: [[5, 5, 5],
  //         [1, 2, -3]]
  // reduction_indices: [1]
  auto input = op.input(0);
  auto reduction_indices = op.input(1);

  // [2, 3]
  auto input_shape = Shape(scope, input);

  // [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // for op=min (say)
  // output = [5, -3]
  // y = [[5],
  //      [-3]]
  auto y = Reshape(scope, op.output(0), output_shape_kept_dims);

  // reshape([g1, g2], [2, 1]) = [[g1],
  //                              [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // indicators = equal(y, input)
  //  = equal([[5],   [[5, 5, 5],
  //           [-3]],  [1, 2, -3]])
  //  = [[1, 1, 1],
  //     [0, 0, 1]]
  auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type());

  // [[3],
  //  [1]]
  auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices),
                              output_shape_kept_dims);

  // [[1/3, 1/3, 1/3],
  //  [0, 0, 1]]
  auto scale = Div(scope, indicators, num_selected);

  // [[g1/3, g1/3, g1/3],
  //  [0, 0, g2]]
  grad_outputs->push_back(Mul(scope, scale, grad));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);

Status ProdGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // The gradient can be expressed by dividing the product by each entry of
  // the input tensor. If our input is
  // [
  //  [3, 4],
  //  [5, 6],
  //  [7, 8]
  // ]
  // and we do a Prod operation on the axis 0, we will obtain [[105, 192]].
  // The gradient will have the same shape as the input
  //     [
  //       [105/3, 192/4],
  // dz *  [105/5, 192/6],
  //       [105/7, 192/8]
  //     ]
  // If the input contains a zero, the division is impossible but
  // if we take the calculation that gave the first gradient
  // (3 * 5 * 7)/3 is equal to 5 * 7
  // the trick will be to cumprod the elements on the axis without
  // the element at the current position (3 in the example above).
  // We will take as example:
  // [
  //   [
  //     [3.0, 4.0],
  //     [5.0, 6.0],
  //     [7.0, 8.0]
  //   ],
  //   [
  //     [3.0, 5.0],
  //     [0.0, 6.0],
  //     [5.0, 6.0]
  //   ]
  // ]

  // [2, 3, 2]
  auto input_shape = Shape(scope, op.input(0));

  // The Reshape with -1 flattens the reduction indices.
  // [1]
  auto reduction_indices = Reshape(scope, op.input(1), {-1});

  // [2, 1, 2]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // [1, 3, 1]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // [[[105, 192]], [[0, 180]]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // [[[105, 192], [105, 192], [105, 192]], [[0, 180], [0, 180], [0, 180]]]
  auto grad_tiled = Tile(scope, grad, tile_scaling);

  Scope cpu_scope = scope.WithDevice("/cpu:0");

  // [3]
  auto rank = Rank(cpu_scope, op.input(0));

  // Normalize any negative indices in the reduction_axes to positive values.
  auto reduction_indices_pos =
      Mod(cpu_scope, Add(cpu_scope, reduction_indices, rank), rank);

  // [1]
  auto reduced = Cast(cpu_scope, reduction_indices_pos, DataType::DT_INT32);

  // [0, 1, 2]
  auto idx = Range(cpu_scope, zero, rank, one);

  // [0, 2]
  auto other = SetDiff1D(cpu_scope, idx, reduced).out;

  // [1, 0, 2]
  auto perm =
      Concat(cpu_scope, std::initializer_list<Input>{reduced, other}, 0);

  // 3 => [3]
  auto reduced_num = Prod(cpu_scope, Gather(scope, input_shape, reduced), 0);

  // 2 * 2 => [2]
  auto other_num = Prod(cpu_scope, Gather(scope, input_shape, other), 0);

  // [
  //    [
  //       [ 3.,  4.],
  //       [ 3.,  5.]
  //   ],
  //   [
  //       [ 5.,  6.],
  //       [ 0.,  6.]
  //   ],
  //   [
  //       [ 7.,  8.],
  //       [ 5.,  6.]
  //   ]
  // ]
  auto permuted = Transpose(scope, op.input(0), perm);

  // [3, 2, 2]
  auto permuted_shape = Shape(scope, permuted);

  // [
  //   [ 3.,  4.,  3.,  5.],
  //   [ 5.,  6.,  0.,  6.],
  //   [ 7.,  8.,  5.,  6.]
  // ]
  auto reshaped = Reshape(
      scope, permuted,
      Stack(scope, std::initializer_list<Input>{reduced_num, other_num}));

  // [
  //   [ 1.,  1.,  1.,  1.],
  //   [ 3.,  4.,  3.,  5.],
  //   [ 15.,  24.,  0.,  30.]
  // ]
  auto left = Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true));

  // [
  //   [ 35.,  48.,  0.,  36.],
  //   [  7.,   8.,   5.,   6.],
  //   [  1.,   1.,   1.,   1.]
  // ]
  auto right =
      Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true).Reverse(true));

  // left * right =
  // [
  //   [ 35.,  48.,  0.,  36.],
  //   [ 21.,  32.,  15.,  30.],
  //   [ 15.,  24.,  0.,  30.]
  // ]
  // y =
  // [
  //   [
  //     [ 35.,  48.],
  //     [ 0.,  36.]
  //   ],
  //   [
  //     [ 21.,  32.],
  //     [ 15.,  30.]
  //   ],
  //   [
  //     [ 15.,  24.],
  //     [ 0.,  30.]
  //   ]
  // ]
  auto y = Reshape(scope, Mul(scope, left, right), permuted_shape);

  // out =
  // [
  //   [
  //     [ 35.,  48.],
  //     [ 21.,  32.],
  //     [ 15.,  24.]
  //   ],
  //   [
  //     [ 0.,   36.],
  //     [ 15.,  30.],
  //     [ 0.,  30.]
  //   ]
  // ]
  auto out = Mul(scope, grad_tiled,
                 Transpose(scope, y, InvertPermutation(scope, perm)));

  grad_outputs->push_back(Reshape(scope, out, input_shape));

  // stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Prod", ProdGrad);

Status SegmentSumGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  // The SegmentSum operation sums segments of the Tensor that have the same
  // index in the segment_ids parameter.
  // i.e. z = [2, 3, 4, 5], segment_ids [0, 0, 0, 1]
  // will produce [2 + 3 + 4, 5] = [9, 5]
  // The incoming gradient will look like [x1, x2]; it has the same shape as
  // the output of the SegmentSum operation. The differentiation step of the
  // SegmentSum operation just broadcasts that gradient back to z's shape.
  // dy/dz = [x1, x1, x1, x2]
  grad_outputs->push_back(Gather(scope, grad_inputs[0], op.input(1)));

  // stop propagation along segment_ids
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SegmentSum", SegmentSumGrad);

// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (is_batch == false) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// MatMulGradCommon reads and checks the node's transposition attrs, then
// chooses the proper MatMul products for the gradients based on the input
// matrix transposition combination.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  auto a = op.input(0);
  auto b = op.input(1);
  // Use conjugate of the inputs for MatMul
  if (is_batch == false) {
    a = ConjugateHelper(scope, a);
    b = ConjugateHelper(scope, b);
  }
  auto product = op.output(0);

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_y, &tb));

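  // For the untransposed case C = A * B: dA = dC * B^T and dB = A^T * dC.
  // The branches below permute operands and transpose flags to cover the
  // remaining combinations.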
  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, true, a,
                            true, grad_inputs[0], false, grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, false,
                            grad_inputs[0], true, a, false, grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, b, false, grad_inputs[0], true, a,
                            false, grad_inputs[0], false, grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, b, true, grad_inputs[0], true,
                          grad_inputs[0], true, a, true, grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

Status CumsumGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  if (op.num_inputs() != 2) {
    return errors::InvalidArgument("Cumsum requires 2 arguments");
  }
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("Cumsum grad requires 1 grad input");
  }

  Cumsum::Attrs attrs;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "exclusive", &attrs.exclusive_));
  bool reverse;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "reverse", &reverse));
  attrs.reverse_ = !reverse;

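  // The gradient of a cumulative sum is the cumulative sum of the incoming
  // gradient taken in the opposite direction along the same axis, with the
  // same exclusivity.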
  auto axis = op.input(1);
  auto sum = Cumsum(scope, grad_inputs[0], axis, attrs);
  grad_outputs->push_back(sum.out);
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Cumsum", CumsumGrad);

bool IsFloatingPointDtype(DataType dtype) {
  static constexpr DataType valid_dtypes[] = {
      DT_FLOAT, DT_HALF, DT_DOUBLE, DT_BFLOAT16, DT_COMPLEX64, DT_COMPLEX128};
  return std::find(std::begin(valid_dtypes), std::end(valid_dtypes), dtype) !=
         std::end(valid_dtypes);
}

Status CastGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  if (op.num_inputs() != 1) {
    return errors::InvalidArgument("Cast requires 1 argument");
  }
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("Cast grad requires 1 grad input");
  }

  auto src_type = op.input_type(0);
  auto dst_type = grad_inputs[0].type();
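  // A gradient only makes sense when both the source and destination types
  // are floating point (or complex); integer and boolean casts are treated as
  // non-differentiable.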
  if (IsFloatingPointDtype(src_type) && IsFloatingPointDtype(dst_type)) {
    grad_outputs->push_back(Cast(scope, grad_inputs[0], src_type));
  } else {
    grad_outputs->push_back(NoGradient());
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Cast", CastGrad);

Status SelectGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  if (op.num_inputs() != 3) {
    return errors::InvalidArgument("Select requires 3 arguments");
  }
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("Select grad requires 1 grad input");
  }

  auto c = op.input(0);
  auto zeros = ZerosLike(scope, grad_inputs[0]);
  grad_outputs->push_back(NoGradient());  // Condition
  grad_outputs->push_back(Where3(scope, c, grad_inputs[0], zeros));
  grad_outputs->push_back(Where3(scope, c, zeros, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Select", SelectGrad);

Status SelectV2Grad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  if (op.num_inputs() != 3) {
    return errors::InvalidArgument("SelectV2 requires 3 arguments");
  }

  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("SelectV2 grad requires 1 grad input");
  }

  auto c = op.input(0);
  auto x = op.input(1);
  auto y = op.input(2);

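  // Unlike Select, SelectV2 broadcasts its condition and value operands
  // against each other, so each operand's gradient must be summed over the
  // broadcast dimensions and reshaped back to that operand's shape.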
  auto zeros = ZerosLike(scope, grad_inputs[0]);
  auto gx = SelectV2(scope, c, grad_inputs[0], zeros);
  auto x_shape = Shape(scope, x);
  auto output_shape = Shape(scope, op.output(0));

  // Reduce away broadcasted leading dims.
  auto reduce_x = internal::BroadcastGradientArgs(scope, x_shape, output_shape);
  auto gx_sum =
      ReduceSum(scope, gx, /*axis=*/reduce_x.r0, ReduceSum::KeepDims(true));
  auto gx_sum_reshape = Reshape(scope, gx_sum, x_shape);

  auto gy = SelectV2(scope, c, zeros, grad_inputs[0]);
  auto y_shape = Shape(scope, y);

  // Reduce away broadcasted leading dims.
  auto reduce_y = internal::BroadcastGradientArgs(scope, y_shape, output_shape);
  auto gy_sum =
      ReduceSum(scope, gy, /*axis=*/reduce_y.r0, ReduceSum::KeepDims(true));
  auto gy_sum_reshape = Reshape(scope, gy_sum, y_shape);

  grad_outputs->push_back(NoGradient());  // Condition
  grad_outputs->push_back(gx_sum_reshape);
  grad_outputs->push_back(gy_sum_reshape);
  return scope.status();
}

REGISTER_GRADIENT_OP("SelectV2", SelectV2Grad);

// Helper function for unsorted segment ops.
// Returns 'ids' with negative elements replaced by 0.
Output GetZeroClippedIndices(const Scope& scope, const Output& ids) {
  return Maximum(scope, ids, ZerosLike(scope, ids));
}

// Helper function for unsorted segment ops.
// Returns a mask of where 'ids' are positive, reshaped so that it will be
// broadcastable to the result shape of gathering params by ids.
Output GetIsPositive(const Scope& scope, const Output& params,
                     const Output& ids) {
  Output is_positive = GreaterEqual(scope, ids, ZerosLike(scope, ids));
  // tf.where(condition, x, y) requires condition to have the same shape as x
  // and y.
  Output is_positive_shape = Shape(scope, is_positive);
  Output ones =
      Tile(scope, Const(scope, {1}), Subtract(scope, Rank(scope, params), {1}));
  auto broadcastable_shape = Concat(scope, {is_positive_shape, ones},
                                    /*axis=*/0);
  is_positive = Reshape(scope, is_positive, broadcastable_shape);
  is_positive = LogicalAnd(scope, is_positive, OnesLike(scope, is_positive));
  return is_positive;
}

// Helper function for unsorted segment ops.
// Gathers params for positive segment ids and gathers 0 for inputs with
// negative segment id.
Output GatherDropNegatives(const Scope& scope, const Output& params,
                           Output& zero_clipped_indices, Output& is_positive) {
  auto gathered = Gather(scope, params, zero_clipped_indices);
  // Replace gathered params of negative indices with 0.
  auto zero_slice = ZerosLike(scope, gathered);
  return SelectV2(scope, is_positive, gathered, zero_slice);
}

Status UnsortedSegmentMinOrMaxGrad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  if (op.num_inputs() != 3) {
    return errors::InvalidArgument("UnsortedSegmentMax requires 3 arguments");
  }

  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument(
        "UnsortedSegmentMax grad requires 1 grad input");
  }

  auto grad = grad_inputs[0];
  // Get the number of selected (minimum or maximum) elements in each segment.
  auto zero_clipped_indices = GetZeroClippedIndices(scope, op.input(1));
  auto is_positive = GetIsPositive(scope, op.output(0), op.input(1));
  Output gathered_outputs = GatherDropNegatives(
      scope, op.output(0), zero_clipped_indices, is_positive);
  Output is_selected = Equal(scope, op.input(0), gathered_outputs);
  is_selected = LogicalAnd(scope, is_selected, is_positive);
  auto num_selected = UnsortedSegmentSum(
      scope, Cast(scope, is_selected, grad.type()), op.input(1), op.input(2));
  // Compute the gradient for each segment. The gradient for the ith segment is
  // divided evenly among the selected elements in that segment.
  auto weighted_grads = Div(scope, grad, num_selected);
  auto gathered_grads = GatherDropNegatives(scope, weighted_grads,
                                            zero_clipped_indices, is_positive);
  auto zeros = ZerosLike(scope, gathered_grads);
  grad_outputs->push_back(SelectV2(scope, is_selected, gathered_grads, zeros));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}

REGISTER_GRADIENT_OP("UnsortedSegmentMax", UnsortedSegmentMinOrMaxGrad);
REGISTER_GRADIENT_OP("UnsortedSegmentMin", UnsortedSegmentMinOrMaxGrad);

Status UnsortedSegmentSumGrad(const Scope& scope, const Operation& op,
                              const std::vector<Output>& grad_inputs,
                              std::vector<Output>* grad_outputs) {
  if (op.num_inputs() != 3) {
    return errors::InvalidArgument("UnsortedSegmentSum requires 3 arguments");
  }

  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument(
        "UnsortedSegmentSum grad requires 1 grad input");
  }

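  // The gradient of UnsortedSegmentSum is a gather of the incoming gradient
  // by segment id; entries with negative segment ids receive zero gradient.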
  auto zero_clipped_indices = GetZeroClippedIndices(scope, op.input(1));
  auto is_positive = GetIsPositive(scope, grad_inputs[0], op.input(1));
  grad_outputs->push_back(GatherDropNegatives(
      scope, grad_inputs[0], zero_clipped_indices, is_positive));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}

REGISTER_GRADIENT_OP("UnsortedSegmentSum", UnsortedSegmentSumGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow