/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define _USE_MATH_DEFINES
#include <cmath>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/math_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

// Logical operations have no gradients.
REGISTER_NO_GRADIENT_OP("Less");
REGISTER_NO_GRADIENT_OP("LessEqual");
REGISTER_NO_GRADIENT_OP("Greater");
REGISTER_NO_GRADIENT_OP("GreaterEqual");
REGISTER_NO_GRADIENT_OP("Equal");
REGISTER_NO_GRADIENT_OP("ApproximateEqual");
REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");
REGISTER_NO_GRADIENT_OP("Floor");

// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}

// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy;
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
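  // y = 1 / x
  // dy/dx = -1 / x^2 = -y^2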
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::ReciprocalGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
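  // y = sqrt(x)
  // dy/dx = 0.5 / sqrt(x) = 0.5 / y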
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::SqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
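  // y = 1 / sqrt(x)
  // dy/dx = -0.5 * x^(-3/2) = -0.5 * y^3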
  // Use the built-in operator.
  grad_outputs->push_back(
      internal::RsqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);

Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::TanhGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::SigmoidGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
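  // y = sign(x) is piecewise constant, so dy/dx = 0 wherever it is defined.
  // The gradient is a zero tensor with the same shape as x.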
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = -1 / sqrt(1 - x^2)
  // dx = dy * (-1 / sqrt(1 - x^2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

// BinaryGradCommon handles the setup for binary ops that broadcast
// their inputs.
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
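  // Sum each incoming gradient over the dimensions that were broadcast and
  // reshape the result back to the corresponding input's shape.
  // BroadcastGradientArgs returns the reduction axes for each operand.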
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
  auto rx = internal::BroadcastGradientArgs(scope, sx_1, sx_2);
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}

Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);

Status DivNoNanGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = DivNoNan(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  DivNoNan(scope, DivNoNan(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("DivNoNan", DivNoNanGrad);

Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  // hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32 i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

Status PowGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = ConjugateHelper(scope, op.input(0));
  auto y = ConjugateHelper(scope, op.input(1));
  auto z = ConjugateHelper(scope, op.output(0));
  auto grad = grad_inputs[0];
  // grad * y * pow(x, y - 1)
  auto one = Cast(scope, Const(scope, 1.0), y.type());
  auto gx_1 = Mul(scope,
                  Mul(scope, grad, y),
                  Pow(scope, x, Sub(scope, y, one)));
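  // The gradient with respect to the exponent is grad * z * log(x), since
  // dz/dy = pow(x, y) * log(x) = z * log(x); log(x) is masked to zero below
  // where it is undefined (or has no sensible real value).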
  // Avoid false singularity at x = 0
  DataType x_dtype = x.type();
  auto zero = Cast(scope, Const(scope, 0.0), x_dtype);
  if (x_dtype == DT_COMPLEX64 || x_dtype == DT_COMPLEX128) {
    // real(x) < 0 is fine for the complex case
    auto log_x = Where3(scope,
                        NotEqual(scope, x, zero),
                        Log(scope, x),
                        ZerosLike(scope, x));
    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
  } else {
    // There's no sensible real value to return if x < 0, so return 0
    auto log_x = Where3(scope,
                        Greater(scope, x, zero),
                        Log(scope, x),
                        ZerosLike(scope, x));
    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
  }
}
REGISTER_GRADIENT_OP("Pow", PowGrad);

// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, comparator, zeros, grad);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
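  // y = Real(x)
  // The gradient flows back only into the real part: dx = Complex(dy, 0).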
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
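  // y = Imag(x)
  // The gradient flows back only into the imaginary part: dx = Complex(0, dy).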
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ComplexGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
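  // y = Complex(re, im): the gradients w.r.t. the real and imaginary inputs
  // are Real(dy) and Imag(dy); BinaryGradCommon handles any broadcasting.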
  auto gx_1 = Real(scope, grad_inputs[0]);
  auto gx_2 = Imag(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Complex", ComplexGrad);

Status AngleGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = Angle(x)
  // dx = -dy / (Im(x) + iRe(x)) = -dy * z
  auto re = Real(scope, op.input(0));
  auto im = Imag(scope, op.input(0));
  auto z_inv = Reciprocal(scope, Complex(scope, im, re));
  auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
  auto grad = Complex(scope, grad_inputs[0], zero);
  auto dx = Neg(scope, Mul(scope, grad, z_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Angle", AngleGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
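  // y = Conj(x): the gradient is the conjugate of the incoming gradient.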
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// Integer division x / y, assuming x and y >= 0, treating x / 0 as x.
Output SafeDivHelper(const Scope& scope, const Output& x, const Output& y) {
  return Div(scope, x, Maximum(scope, y, Const(scope, 1)));
}

// Helper function for reduction ops.
//
// input_shape: 1-D Tensor, the shape of the Tensor being reduced.
// axes: 1-D Tensor, the reduction axes.
//   Note that the reduction indices are in the range
//   [-rank(input_shape), rank(input_shape)).
// returns a 1-D Tensor, the output shape as if keep_dims were set to True.
Output ReducedShapeHelper(const Scope& scope, const Output& input_shape,
                          const Output& reduction_axes) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // Running example in comments
  // input_shape = [2, 3, 5, 7]
  // axes = [1, 2]
  // The result (a shape after a reduction with keep_dims=True)
  // [2, 1, 1, 7]
  //
  // We can treat each entry in axes as an index into input_shape that
  // should be replaced by 1.
  // We use DynamicStitch to do this.

  // input_rank = 4
  auto input_rank = Size(scope, input_shape);

  // Normalize any negative indices in the reduction_axes to positive
  // values.
  auto axes = Mod(scope, Add(scope, reduction_axes, input_rank), input_rank);

  // This [0..input_rank) range of integers is used in DynamicStitch to
  // first copy input_shape to the result.
  // input_rank_range = [0, 1, 2, 3]
  auto input_rank_range = Range(scope, zero, input_rank, one);

  // A 1-filled tensor with the same shape as axes. DynamicStitch will
  // merge these 1s (using axes for indices) to the correct
  // position in the result.
  // axes_ones = [1, 1]
  auto axes_ones = OnesLike(scope, axes);

  // using DynamicStitch:
  // indices = { input_rank_range, axes }
  //         = { [0, 1, 2, 3], [1, 2] }
  // data = { input_shape, axes_ones }
  //      = { [2, 3, 5, 7], [1, 1] }
  // The input_rank_range entry in indices first replicates the
  // input_shape to the result.
  // The axes entry in indices then moves a 1 to each of its entries,
  // resulting in
  // [2, 1, 1, 7]
  std::vector<Output> indices = {input_rank_range, axes};
  std::vector<Output> data = {input_shape, axes_ones};
  return DynamicStitch(scope, indices, data);
}

// SumGradHelper returns the gradient for the Sum operator, and is used
// by SumGrad and MeanGrad.
Output SumGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is just 1, so we only need to replicate the output gradient on such a
  // dimension to its "expanded" shape.
  // Running example:
  // input is
  // [[a, b, c],
  //  [d, e, f]]
  // reduction_indices = [1]
  // Sum = [a + b + c, d + e + f]
  // if the gradient is [g1, g2]
  // We want the propagated gradient to be
  // [[g1, g1, g1],
  //  [g2, g2, g2]]

  // input_shape = [2, 3]
  auto input_shape = Shape(scope, op.input(0));

  // output_shape_kept_dims = [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, op.input(1));

  // This step "flips" any 1s with values from the input_shape, and
  // replaces remaining entries with 1. This creates a shape that
  // shows how much each dimension in the incoming gradient should be
  // replicated.
  // tile_scaling = [1, 3]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // grad = [[g1], [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // tile(grad, tile_scaling) = [[g1, g1, g1], [g2, g2, g2]]
  return Tile(scope, grad, tile_scaling);
}

Status SumGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(SumGradHelper(scope, op, grad_inputs));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Sum", SumGrad);

Status MeanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // The Mean gradient is just like the Sum gradient, except that
  // all gradients are also divided by the size of reduced groups.
  auto sum_grad = SumGradHelper(scope, op, grad_inputs);

  // The product of all entries in a tensor's shape is the total
  // number of entries in the tensor. This step calculates
  // n_input_entries/n_output_entries
  // = group_size
  auto input_shape = Shape(scope, op.input(0));
  auto output_shape = Shape(scope, op.output(0));
  auto zero = Const(scope, 0);
  auto group_size = SafeDivHelper(scope, Prod(scope, input_shape, zero),
                                  Prod(scope, output_shape, zero));

  // propagate sum_grad/group_size
  grad_outputs->push_back(
      Div(scope, sum_grad, Cast(scope, group_size, sum_grad.type())));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Mean", MeanGrad);

Status ErfGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto grad = grad_inputs[0];
  auto two_over_root_pi = Cast(scope, Const(scope, 2 / std::sqrt(M_PI)),
                               grad.type());
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
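  // y = erf(x)
  // dy/dx = 2/sqrt(pi) * exp(-x^2), so dx is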
  // grad * 2/sqrt(pi) * exp(-x**2)
  auto dx = Mul(grad_scope,
                Mul(grad_scope, grad, two_over_root_pi),
                Exp(grad_scope, Neg(grad_scope, Square(grad_scope, x))));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Erf", ErfGrad);

Status LgammaGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  auto grad = grad_inputs[0];
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
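  // y = lgamma(x)
  // dy/dx = digamma(x), so dx = grad * digamma(x).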
  auto dx = Mul(grad_scope, grad, Digamma(grad_scope, x));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Lgamma", LgammaGrad);

Status MinOrMaxGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is 1 when it is the min (or max) and 0 everywhere else. So the
  // gradient calculation is identical for both operators.
  //
  // There's a special case for propagating gradients when there are
  // multiple minima (or maxima) - we choose to divide the gradient
  // equally among all matching inputs.
  //
  // Please see this comment
  // https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063
  // for details.

  // Running example:
  // input: [[5, 5, 5],
  //         [1, 2, -3]]
  // reduction_indices: [1]
  auto input = op.input(0);
  auto reduction_indices = op.input(1);

  // [2, 3]
  auto input_shape = Shape(scope, input);

  // [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // for op=min (say)
  // output = [5, -3]
  // y = [[5],
  //      [-3]]
  auto y = Reshape(scope, op.output(0), output_shape_kept_dims);

  // reshape([g1, g2], [2, 1]) = [[g1],
  //                              [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // indicators = equal(y, input)
  //  = equal([[5],   [[5, 5, 5],
  //           [-3]],  [1, 2, -3]])
  //  = [[1, 1, 1],
  //     [0, 0, 1]]
  auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type());

  // [[3],
  //  [1]]
  auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices),
                              output_shape_kept_dims);

  // [[1/3, 1/3, 1/3],
  //  [0, 0, 1]]
  auto scale = Div(scope, indicators, num_selected);

  // [[g1/3, g1/3, g1/3],
  //  [0, 0, g2]]
  grad_outputs->push_back(Mul(scope, scale, grad));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);

Status ProdGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // The gradient can be expressed by dividing the product by each entry of
  // the input tensor. If our input is
  // [
  //  [3, 4],
  //  [5, 6],
  //  [7, 8]
  // ]
  // and we take the Prod along axis 0, we obtain [[105, 192]].
  // The gradient will have the same shape as the input
  //     [
  //       [105/3, 192/4],
  // dz *  [105/5, 192/6],
  //       [105/7, 192/8]
  //     ]
  // If the input contains a zero, the division is impossible, but note that
  // the calculation that gave the first gradient, (3 * 5 * 7) / 3, is equal
  // to 5 * 7. The trick is therefore to cumprod the elements on the axis
  // without the element at the current position (3 in the example above).
  // We will take as example:
  // [
  //   [
  //     [3.0, 4.0],
  //     [5.0, 6.0],
  //     [7.0, 8.0]
  //   ],
  //   [
  //     [3.0, 5.0],
  //     [0.0, 6.0],
  //     [5.0, 6.0]
  //   ]
  // ]

  // [2, 3, 2]
  auto input_shape = Shape(scope, op.input(0));

  // The Reshape with -1 flattens the reduction indices.
  // [1]
  auto reduction_indices = Reshape(scope, op.input(1), {-1});

  // [2, 1, 2]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // [1, 3, 1]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // [[[105, 192]], [[0, 180]]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // [[[105, 192], [105, 192], [105, 192]], [[0, 180], [0, 180], [0, 180]]]
  auto grad_tiled = Tile(scope, grad, tile_scaling);

  Scope cpu_scope = scope.WithDevice("/cpu:0");

  // [3]
  auto rank = Rank(cpu_scope, op.input(0));

  // Normalize any negative indices in the reduction_axes to positive values.
  auto reduction_indices_pos =
      Mod(cpu_scope, Add(cpu_scope, reduction_indices, rank), rank);

  // [1]
  auto reduced = Cast(cpu_scope, reduction_indices_pos, DataType::DT_INT32);

  // [0, 1, 2]
  auto idx = Range(cpu_scope, zero, rank, one);

  // [0, 2]
  auto other = SetDiff1D(cpu_scope, idx, reduced).out;

  // [1, 0, 2]
  auto perm =
      Concat(cpu_scope, std::initializer_list<Input>{reduced, other}, 0);

  // 3 => [3]
  auto reduced_num = Prod(cpu_scope, Gather(scope, input_shape, reduced), 0);

  // 2 * 2 => [2]
  auto other_num = Prod(cpu_scope, Gather(scope, input_shape, other), 0);

  // [
  //    [
  //       [ 3.,  4.],
  //       [ 3.,  5.]
  //   ],
  //   [
  //       [ 5.,  6.],
  //       [ 0.,  6.]
  //   ],
  //   [
  //       [ 7.,  8.],
  //       [ 5.,  6.]
  //   ]
  // ]
  auto permuted = Transpose(scope, op.input(0), perm);

  // [3, 2, 2]
  auto permuted_shape = Shape(scope, permuted);

  // [
  //   [ 3.,  4.,  3.,  5.],
  //   [ 5.,  6.,  0.,  6.],
  //   [ 7.,  8.,  5.,  6.]
  // ]
  auto reshaped = Reshape(
      scope, permuted,
      Stack(scope, std::initializer_list<Input>{reduced_num, other_num}));

  // [
  //   [ 1.,  1.,  1.,  1.],
  //   [ 3.,  4.,  3.,  5.],
  //   [ 15.,  24.,  0.,  30.]
  // ]
  auto left = Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true));

  // [
  //   [ 35.,  48.,  0.,  36.],
  //   [  7.,   8.,   5.,   6.],
  //   [  1.,   1.,   1.,   1.]
  // ]
  auto right =
      Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true).Reverse(true));

  // left * right =
  // [
  //   [ 35.,  48.,  0.,  36.],
  //   [ 21.,  32.,  15.,  30.],
  //   [ 15.,  24.,  0.,  30.]
  // ]
  // y =
  // [
  //   [
  //     [ 35.,  48.],
  //     [ 0.,  36.]
  //   ],
  //   [
  //     [ 21.,  32.],
  //     [ 15.,  30.]
  //   ],
  //   [
  //     [ 15.,  24.],
  //     [ 0.,  30.]
  //   ]
  // ]
  auto y = Reshape(scope, Mul(scope, left, right), permuted_shape);

  // out =
  // [
  //   [
  //     [ 35.,  48.],
  //     [ 21.,  32.],
  //     [ 15.,  24.]
  //   ],
  //   [
  //     [ 0.,   36.],
  //     [ 15.,  30.],
  //     [ 0.,  30.]
  //   ]
  // ]
  auto out = Mul(scope, grad_tiled,
                 Transpose(scope, y, InvertPermutation(scope, perm)));

  grad_outputs->push_back(Reshape(scope, out, input_shape));

  // stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Prod", ProdGrad);

Status SegmentSumGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  // The SegmentSum operation sums segments of the Tensor that have the same
  // index in the segment_ids parameter.
  // e.g. z = [2, 3, 4, 5] with segment_ids [0, 0, 0, 1]
  // will produce [2 + 3 + 4, 5] = [9, 5].
  // The incoming gradient will look like [x1, x2]; it has the same shape as
  // the output of the SegmentSum operation. The differentiation step of the
  // SegmentSum operation just broadcasts the gradient back to z's shape by
  // gathering it with segment_ids:
  // dy/dz = [x1, x1, x1, x2]
  grad_outputs->push_back(Gather(scope, grad_inputs[0], op.input(1)));

  // stop propagation along segment_ids
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SegmentSum", SegmentSumGrad);

// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (is_batch == false) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// MatMulGradCommon reads and checks the node's attr state, and determines the
// proper MatMul products for the gradients based on the input matrix
// transposition combinations.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  auto a = op.input(0);
  auto b = op.input(1);
  // Use conjugate of the inputs for MatMul
  if (is_batch == false) {
    a = ConjugateHelper(scope, a);
    b = ConjugateHelper(scope, b);
  }
  auto product = op.output(0);

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_y, &tb));

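  // For the base case Z = A * B (no transposes):
  //   dA = dZ * B^T and dB = A^T * dZ.
  // The remaining branches rearrange the operands and transpose/adjoint flags
  // for each combination of the transpose_a/transpose_b (or adj_x/adj_y)
  // attributes.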
  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, true, a,
                            true, grad_inputs[0], false, grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, false,
                            grad_inputs[0], true, a, false, grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, b, false, grad_inputs[0], true, a,
                            false, grad_inputs[0], false, grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, b, true, grad_inputs[0], true,
                          grad_inputs[0], true, a, true, grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow