/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define _USE_MATH_DEFINES
#include <cmath>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/math_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

// Logical operations have no gradients.
REGISTER_NO_GRADIENT_OP("Less");
REGISTER_NO_GRADIENT_OP("LessEqual");
REGISTER_NO_GRADIENT_OP("Greater");
REGISTER_NO_GRADIENT_OP("GreaterEqual");
REGISTER_NO_GRADIENT_OP("Equal");
REGISTER_NO_GRADIENT_OP("ApproximateEqual");
REGISTER_NO_GRADIENT_OP("NotEqual");
REGISTER_NO_GRADIENT_OP("LogicalAnd");
REGISTER_NO_GRADIENT_OP("LogicalOr");
REGISTER_NO_GRADIENT_OP("LogicalNot");
REGISTER_NO_GRADIENT_OP("Floor");

// Conjugate helper function returns the conjugate of an Output if it
// is complex valued.
Output ConjugateHelper(const Scope& scope, const Output& out) {
  DataType dtype = out.type();
  if (dtype == DT_COMPLEX64 || dtype == DT_COMPLEX128) {
    return Conj(scope, out);
  } else {
    return out;
  }
}

// TODO(andydavis) Add control dependencies to gradient functions (as needed).

Status AbsGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = dy * sign(x)
  grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Abs", AbsGrad);

Status NegGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dx = -dy;
  grad_outputs->push_back(Neg(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Neg", NegGrad);

Status InvGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
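  // y = 1/x, so dy/dx = -1/x^2 = -y^2; ReciprocalGrad forms the product with
  // the incoming gradient directly from y and grad.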
  grad_outputs->push_back(
      internal::ReciprocalGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Inv", InvGrad);
REGISTER_GRADIENT_OP("Reciprocal", InvGrad);

Status SquareGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  // dy/dx = (2 * x)
  auto two = Cast(scope, Const(scope, 2), op.input(0).type());
  auto dydx = Mul(scope, two, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Square", SquareGrad);

Status SqrtGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
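  // y = sqrt(x), so dy/dx = 0.5 / sqrt(x) = 0.5 / y; SqrtGrad fuses the
  // multiplication with the incoming gradient.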
  grad_outputs->push_back(
      internal::SqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);

Status RsqrtGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
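  // y = x^(-1/2), so dy/dx = -0.5 * x^(-3/2) = -0.5 * y^3; RsqrtGrad fuses the
  // multiplication with the incoming gradient.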
  grad_outputs->push_back(
      internal::RsqrtGrad(scope, op.output(0), grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);

Status ExpGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // dy/dx = exp(x) = y
  // grad(x) = grad(y) * conj(dy/dx)
  //         = grad(y) * conj(y)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, op.output(0))));
  return scope.status();
}
REGISTER_GRADIENT_OP("Exp", ExpGrad);

Status Expm1Grad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = expm1(x)
  // dy/dx = exp(x)
  auto dydx = Exp(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Expm1", Expm1Grad);

Status LogGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = log(x)
  // dy/dx = 1 / x
  auto dydx = Reciprocal(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log", LogGrad);

Status Log1pGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = log1p(x)
  // dy/dx = 1 / (1 + x)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Log1p", Log1pGrad);

Status SinhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = sinh(x)
  // dy/dx = cosh(x)
  auto dydx = Cosh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sinh", SinhGrad);

Status CoshGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = cosh(x)
  // dy/dx = sinh(x)
  auto dydx = Sinh(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cosh", CoshGrad);

Status TanhGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
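  // y = tanh(x), so dy/dx = 1 - y^2; passing conj(y) to TanhGrad below
  // yields grad * conj(dy/dx).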
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::TanhGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Tanh", TanhGrad);

Status AsinhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = asinh(x)
  // dy/dx = 1 / cosh(y)
  auto dydx = Reciprocal(scope, Cosh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Asinh", AsinhGrad);

Status AcoshGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = acosh(x)
  // dy/dx = 1 / sinh(y)
  auto dydx = Reciprocal(scope, Sinh(scope, op.output(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Acosh", AcoshGrad);

Status AtanhGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = atanh(x)
  // dy/dx = 1 / (1 - x^2)
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sub(scope, one, Square(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Atanh", AtanhGrad);

Status SigmoidGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  // Use the built-in operator.
  // Note that the built-in operator does not return the conjugate of
  // the gradient.
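  // y = sigmoid(x) = 1 / (1 + exp(-x)), so dy/dx = y * (1 - y); passing
  // conj(y) to SigmoidGrad below yields grad * conj(dy/dx).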
  auto grad = grad_inputs[0];
  // Optimization to avoid calculating conj(y) until the gradient is
  // evaluated.
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto y = ConjugateHelper(grad_scope, op.output(0));
  grad_outputs->push_back(internal::SigmoidGrad(grad_scope, y, grad));
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);

Status SignGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
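  // y = sign(x) is piecewise constant, so its derivative is 0 wherever it is
  // defined; the gradient is a zero tensor with the same shape as x.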
  auto shape = Shape(scope, op.input(0));
  auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
  auto dx = Fill(scope, shape, zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Sign", SignGrad);

Status SinGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = sin(x)
  // dy/dx = cos(x)
  auto dydx = Cos(scope, op.input(0));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Sin", SinGrad);

Status CosGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = cos(x)
  // dy/dx = -sin(x)
  auto dydx = Neg(scope, Sin(scope, op.input(0)));
  // grad(x) = grad(y) * conj(dy/dx)
  grad_outputs->push_back(
      Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Cos", CosGrad);

Status AsinGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = asin(x)
  // dy/dx = 1 / sqrt(1 - x^2)
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2)));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Asin", AsinGrad);

Status AcosGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = acos(x)
  // dy/dx = -1 / sqrt(1 - x^2)
  // dx = dy * (-1 / sqrt(1 - x^2))
  auto x2 = Square(scope, op.input(0));
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Neg(scope, Reciprocal(scope, Sqrt(scope, Sub(scope, one, x2))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Acos", AcosGrad);

Status TanGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = tan(x)
  // dy/dx = sec(x)^2 = 1 / cos(x)^2
  auto dydx = Square(scope, Reciprocal(scope, Cos(scope, op.input(0))));
  // grad(x) = grad(y) * conj(dy/dx)
  auto dx = Mul(scope, grad_inputs[0], ConjugateHelper(scope, dydx));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Tan", TanGrad);

Status AtanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = arctan(x)
  // dy/dx = 1 / (1 + x^2)
  // dx = dy * (1 / (1 + x^2))
  auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
  auto dydx = Reciprocal(scope, Add(scope, one, Square(scope, op.input(0))));
  auto dx = Mul(scope, grad_inputs[0], dydx);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Atan", AtanGrad);

// BinaryGradCommon handles the setup for binary ops that broadcast
// their inputs.
Status BinaryGradCommon(const Scope& scope, const Operation& op,
                        std::vector<Output>* grad_outputs, const Output& gx_1,
                        const Output& gx_2) {
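  // BroadcastGradientArgs returns, for each input, the axes along which that
  // input was broadcast; summing the gradient over those axes and reshaping
  // restores each input's original shape.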
  auto sx_1 = Shape(scope, op.input(0));
  auto sx_2 = Shape(scope, op.input(1));
  auto rx = internal::BroadcastGradientArgs(scope, sx_1, sx_2);
  auto dx_1 = Reshape(scope, Sum(scope, gx_1, rx.r0), sx_1);
  auto dx_2 = Reshape(scope, Sum(scope, gx_2, rx.r1), sx_2);
  grad_outputs->push_back(dx_1);
  grad_outputs->push_back(dx_2);
  return scope.status();
}

Status AddGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 + x_2
  // dy/dx_1 = dy/dx_2 = 1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Identity(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Add", AddGrad);

Status SubGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  // y = x_1 - x_2
  // dy/dx_1 = 1
  // dy/dx_2 = -1
  auto gx_1 = Identity(scope, grad_inputs[0]);
  auto gx_2 = Neg(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Sub", SubGrad);

Status MulGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 * x_2
  // dy/dx_1 = x_2
  // dy/dx_2 = x_1
  auto gx_1 = Mul(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0], x_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Mul", MulGrad);

Status DivGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = Div(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  Div(scope, Div(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Div", DivGrad);

Status RealDivGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = RealDiv(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  RealDiv(scope, RealDiv(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("RealDiv", RealDivGrad);

Status DivNoNanGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = x_1 / x_2
  // dy/dx_1 = 1/x_2
  // dy/dx_2 = -x_1/x_2^2
  auto gx_1 = DivNoNan(scope, grad_inputs[0], x_2);
  auto gx_2 = Mul(scope, grad_inputs[0],
                  DivNoNan(scope, DivNoNan(scope, Neg(scope, x_1), x_2), x_2));
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("DivNoNan", DivNoNanGrad);

Status SquaredDifferenceGrad(const Scope& scope, const Operation& op,
                             const std::vector<Output>& grad_inputs,
                             std::vector<Output>* grad_outputs) {
  auto x_1 = ConjugateHelper(scope, op.input(0));
  auto x_2 = ConjugateHelper(scope, op.input(1));
  // y = (x_1 - x_2)^2
  // dy/dx_1 = 2 * (x_1 - x_2)
  // dy/dx_2 = -2 * (x_1 - x_2)
  auto two = Cast(scope, Const(scope, 2), grad_inputs[0].type());
  auto gx_1 = Mul(scope, grad_inputs[0], Mul(scope, two, Sub(scope, x_1, x_2)));
  auto gx_2 = Neg(scope, gx_1);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("SquaredDifference", SquaredDifferenceGrad);

Status AddNGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // AddN doesn't support broadcasting, so all the inputs must be the
  // same shape.
  // Note:
  // dy/dx_k = d(x_1 + x_2 + ... + x_n)/dx_k = 1 for all x_k
  // hence dx_k = dy for all x_k
  // So the gradient for AddN just transfers the incoming gradient to
  // all outgoing gradients.
  auto incoming = Identity(scope, grad_inputs[0]);
  for (int32 i = 0; i < op.num_inputs(); ++i) {
    grad_outputs->push_back(incoming);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("AddN", AddNGrad);

Status PowGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = ConjugateHelper(scope, op.input(0));
  auto y = ConjugateHelper(scope, op.input(1));
  auto z = ConjugateHelper(scope, op.output(0));
  auto grad = grad_inputs[0];
  // grad * y * pow(x, y - 1)
  auto one = Cast(scope, Const(scope, 1.0), y.type());
  auto gx_1 = Mul(scope,
                  Mul(scope, grad, y),
                  Pow(scope, x, Sub(scope, y, one)));
  // Avoid false singularity at x = 0
  DataType x_dtype = x.type();
  auto zero = Cast(scope, Const(scope, 0.0), x_dtype);
  if (x_dtype == DT_COMPLEX64 || x_dtype == DT_COMPLEX128) {
    // real(x) < 0 is fine for the complex case
    auto log_x = Where3(scope,
                        NotEqual(scope, x, zero),
                        Log(scope, x),
                        ZerosLike(scope, x));
    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
  } else {
    // There's no sensible real value to return if x < 0, so return 0
    auto log_x = Where3(scope,
                        Greater(scope, x, zero),
                        Log(scope, x),
                        ZerosLike(scope, x));
    auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x);
    return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1);
  }
}
REGISTER_GRADIENT_OP("Pow", PowGrad);

// MaximumMinimumGradCommon adds shared ops to calculate gradients for
// the binary Maximum and Minimum ops.
Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op,
                                const std::vector<Output>& grad_inputs,
                                std::vector<Output>* grad_outputs,
                                const Output& comparator) {
  // comparator is a boolean tensor, with
  // y = x_1 at points where comparator is true, and x_2 otherwise
  // Therefore
  // dy/dx_1 = 1 where comparator is true, and 0 otherwise.
  // dy/dx_2 = 0 where comparator is true, and 1 otherwise.
  auto grad = grad_inputs[0];
  auto zeros = ZerosLike(scope, grad);
  auto gx_1 = Where3(scope, comparator, grad, zeros);
  auto gx_2 = Where3(scope, comparator, zeros, grad);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}

Status MaximumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
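  // y = max(x_1, x_2); the gradient flows to x_1 wherever x_1 >= x_2
  // (ties go to x_1) and to x_2 elsewhere.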
  auto comparator = GreaterEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Maximum", MaximumGrad);

Status MinimumGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
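  // y = min(x_1, x_2); the gradient flows to x_1 wherever x_1 <= x_2
  // (ties go to x_1) and to x_2 elsewhere.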
  auto comparator = LessEqual(scope, op.input(0), op.input(1));
  return MaximumMinimumGradCommon(scope, op, grad_inputs, grad_outputs,
                                  comparator);
}
REGISTER_GRADIENT_OP("Minimum", MinimumGrad);

Status RealGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
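  // y = real(x); the gradient flows only into the real part, so
  // dx = complex(dy, 0).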
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, grad_inputs[0], zero);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Real", RealGrad);

Status ImagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
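  // y = imag(x); the gradient flows only into the imaginary part, so
  // dx = complex(0, dy).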
  auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
  auto dx = Complex(scope, zero, grad_inputs[0]);
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Imag", ImagGrad);

Status ComplexGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
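  // y = complex(x_1, x_2), so dx_1 = real(dy) and dx_2 = imag(dy).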
  auto gx_1 = Real(scope, grad_inputs[0]);
  auto gx_2 = Imag(scope, grad_inputs[0]);
  return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2);
}
REGISTER_GRADIENT_OP("Complex", ComplexGrad);

Status AngleGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // y = Angle(x)
  // dx = -dy / (Im(x) + iRe(x)) = -dy * z_inv
  auto re = Real(scope, op.input(0));
  auto im = Imag(scope, op.input(0));
  auto z_inv = Reciprocal(scope, Complex(scope, im, re));
  auto zero = Cast(scope, Const(scope, 0), grad_inputs[0].type());
  auto grad = Complex(scope, grad_inputs[0], zero);
  auto dx = Neg(scope, Mul(scope, grad, z_inv));
  grad_outputs->push_back(dx);
  return scope.status();
}
REGISTER_GRADIENT_OP("Angle", AngleGrad);

Status ConjGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
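  // y = conj(x); conjugation is its own inverse, so dx = conj(dy).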
  grad_outputs->push_back(Conj(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Conj", ConjGrad);

// Integer division x / y, assuming x and y >= 0, but treats x/0 = x.
Output SafeDivHelper(const Scope& scope, const Output& x, const Output& y) {
  return Div(scope, x, Maximum(scope, y, Const(scope, 1)));
}

// Helper function for reduction ops.
//
// input_shape: 1-D Tensor, the shape of the Tensor being reduced.
// axes: 1-D Tensor, the reduction axes.
//   Note that the reduction indices are in the range
//   [-rank(input_shape), rank(input_shape)).
// returns a 1-D Tensor, the output shape as if keep_dims were set to True.
Output ReducedShapeHelper(const Scope& scope, const Output& input_shape,
                          const Output& reduction_axes) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // Running example in comments
  // input_shape = [2, 3, 5, 7]
  // axes = [1, 2]
  // The result (a shape after a reduction with keep_dims=True)
  // [2, 1, 1, 7]
  //
  // We can treat each entry in axes as an index into input_shape that
  // should be replaced by 1.
  // We use DynamicStitch to do this.

  // input_rank = 4
  auto input_rank = Size(scope, input_shape);

  // Normalize any negative indices in the reduction_axes to positive
  // values.
  auto axes = Mod(scope, Add(scope, reduction_axes, input_rank), input_rank);

  // This [0..input_rank) range of integers is used in DynamicStitch to
  // first copy input_shape to the result.
  // input_rank_range = [0, 1, 2, 3]
  auto input_rank_range = Range(scope, zero, input_rank, one);

  // A 1-filled tensor with the same shape as axes. DynamicStitch will
  // merge these 1s (using axes for indices) to the correct
  // position in the result.
  // axes_ones = [1, 1]
  auto axes_ones = OnesLike(scope, axes);

  // using DynamicStitch:
  // indices = { input_rank_range, axes }
  //         = { [0, 1, 2, 3], [1, 2] }
  // data = { input_shape, axes_ones }
  //      = { [2, 3, 5, 7], [1, 1] }
  // The input_rank_range entry in indices first replicates the
  // input_shape to the result.
  // The axes entry in indices then moves a 1 to each of its entries,
  // resulting in
  // [2, 1, 1, 7]
  std::vector<Output> indices = {input_rank_range, axes};
  std::vector<Output> data = {input_shape, axes_ones};
  return DynamicStitch(scope, indices, data);
}

// SumGradHelper returns the gradient for the Sum operator, and is used
// by SumGrad and MeanGrad.
Output SumGradHelper(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is just 1, so we only need to replicate the output gradient on such a
  // dimension to its "expanded" shape.
  // Running example:
  // input is
  // [[a, b, c],
  //  [d, e, f]]
  // reduction_indices = [1]
  // Sum = [a + b + c, d + e + f]
  // if the gradient is [g1, g2]
  // We want the propagated gradient to be
  // [[g1, g1, g1],
  //  [g2, g2, g2]]

  // input_shape = [2, 3]
  auto input_shape = Shape(scope, op.input(0));

  // output_shape_kept_dims = [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, op.input(1));

  // This step "flips" any 1s with values from the input_shape, and
  // replaces remaining entries with 1. This creates a shape that
  // shows how much each dimension in the incoming gradient should be
  // replicated.
  // tile_scaling = [1, 3]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // grad = [[g1], [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // tile(grad, tile_scaling) = [[g1, g1, g1], [g2, g2, g2]]
  return Tile(scope, grad, tile_scaling);
}

Status SumGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(SumGradHelper(scope, op, grad_inputs));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Sum", SumGrad);

Status MeanGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // The Mean gradient is just like the Sum gradient, except that
  // all gradients are also divided by the size of reduced groups.
  auto sum_grad = SumGradHelper(scope, op, grad_inputs);

  // The product of all entries in a tensor's shape is the total
  // number of entries in the tensor. This step calculates
  // n_input_entries/n_output_entries
  // = group_size
  auto input_shape = Shape(scope, op.input(0));
  auto output_shape = Shape(scope, op.output(0));
  auto zero = Const(scope, 0);
  auto group_size = SafeDivHelper(scope, Prod(scope, input_shape, zero),
                                  Prod(scope, output_shape, zero));

  // propagate sum_grad/group_size
  grad_outputs->push_back(
      Div(scope, sum_grad, Cast(scope, group_size, sum_grad.type())));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Mean", MeanGrad);

Status ErfGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto grad = grad_inputs[0];
  auto two_over_root_pi =
      Cast(scope, Const(scope, 2 / std::sqrt(M_PI)), grad.type());
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
  // grad * 2/sqrt(pi) * exp(-x**2)
  auto dx = Mul(grad_scope,
                Mul(grad_scope, grad, two_over_root_pi),
                Exp(grad_scope, Neg(grad_scope, Square(grad_scope, x))));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Erf", ErfGrad);

Status LgammaGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  auto grad = grad_inputs[0];
  Scope grad_scope = scope.WithControlDependencies(grad);
  auto x = ConjugateHelper(grad_scope, op.input(0));
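  // y = lgamma(x), so dy/dx = digamma(x).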
  auto dx = Mul(grad_scope, grad, Digamma(grad_scope, x));
  grad_outputs->push_back(dx);
  return grad_scope.status();
}
REGISTER_GRADIENT_OP("Lgamma", LgammaGrad);

Status MinOrMaxGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  // The partial derivative for any input along a "reduced" dimension
  // is 1 when it is the min (or max) and 0 everywhere else. So the
  // gradient calculation is identical for both operators.
  //
  // There's a special case for propagating gradients when there are
  // multiple minima (or maxima) - we choose to divide the gradient
  // equally among all matching inputs.
  //
  // See this comment for details:
  // https://github.com/tensorflow/tensorflow/issues/4886#issuecomment-256836063

  // Running example:
  // input: [[5, 5, 5],
  //         [1, 2, -3]]
  // reduction_indices: [1]
  auto input = op.input(0);
  auto reduction_indices = op.input(1);

  // [2, 3]
  auto input_shape = Shape(scope, input);

  // [2, 1]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // for op=min (say)
  // output = [5, -3]
  // y = [[5],
  //      [-3]]
  auto y = Reshape(scope, op.output(0), output_shape_kept_dims);

  // reshape([g1, g2], [2, 1]) = [[g1],
  //                              [g2]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // indicators = equal(y, input)
  //            = equal([[5],     [[5, 5, 5],
  //                     [-3]],    [1, 2, -3]])
  //            = [[1, 1, 1],
  //               [0, 0, 1]]
  auto indicators = Cast(scope, Equal(scope, y, input), grad_inputs[0].type());

  // [[3],
  //  [1]]
  auto num_selected = Reshape(scope, Sum(scope, indicators, reduction_indices),
                              output_shape_kept_dims);

  // [[1/3, 1/3, 1/3],
  //  [0, 0, 1]]
  auto scale = Div(scope, indicators, num_selected);

  // [[g1/3, g1/3, g1/3],
  //  [0, 0, g2]]
  grad_outputs->push_back(Mul(scope, scale, grad));

  // Stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Min", MinOrMaxGrad);
REGISTER_GRADIENT_OP("Max", MinOrMaxGrad);
Status ProdGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  auto zero = Const(scope, 0);
  auto one = Const(scope, 1);

  // The gradient can be expressed by dividing the product by each entry of
  // the input tensor. If our input is
  // [
  //   [3, 4],
  //   [5, 6],
  //   [7, 8]
  // ]
  // and we do a Prod operation on axis 0, we will obtain [[105, 192]].
  // The gradient will have the same shape as the input
  //      [
  //        [105/3, 192/4],
  // dz *   [105/5, 192/6],
  //        [105/7, 192/8]
  //      ]
  // If the input contains a zero, the division is impossible but
  // if we take the calculation that gave the first gradient
  // (3 * 5 * 7)/3 is equal to 5 * 7
  // the trick will be to cumprod the elements on the axis without
  // the element at the current position (3 in the example above).
  // We will take as example:
  // [
  //   [
  //     [3.0, 4.0],
  //     [5.0, 6.0],
  //     [7.0, 8.0]
  //   ],
  //   [
  //     [3.0, 5.0],
  //     [0.0, 6.0],
  //     [5.0, 6.0]
  //   ]
  // ]

  // [2, 3, 2]
  auto input_shape = Shape(scope, op.input(0));

  // The Reshape with -1 flattens the reduction indices.
  // [1]
  auto reduction_indices = Reshape(scope, op.input(1), {-1});

  // [2, 1, 2]
  auto output_shape_kept_dims =
      ReducedShapeHelper(scope, input_shape, reduction_indices);

  // [1, 3, 1]
  auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims);

  // [[[105, 192]], [[0, 180]]]
  auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims);

  // [[[105, 192], [105, 192], [105, 192]], [[0, 180], [0, 180], [0, 180]]]
  auto grad_tiled = Tile(scope, grad, tile_scaling);

  Scope cpu_scope = scope.WithDevice("/cpu:0");

  // [3]
  auto rank = Rank(cpu_scope, op.input(0));

  // Normalize any negative indices in the reduction_axes to positive values.
  auto reduction_indices_pos =
      Mod(cpu_scope, Add(cpu_scope, reduction_indices, rank), rank);

  // [1]
  auto reduced = Cast(cpu_scope, reduction_indices_pos, DataType::DT_INT32);

  // [0, 1, 2]
  auto idx = Range(cpu_scope, zero, rank, one);

  // [0, 2]
  auto other = SetDiff1D(cpu_scope, idx, reduced).out;

  // [1, 0, 2]
  auto perm =
      Concat(cpu_scope, std::initializer_list<Input>{reduced, other}, 0);

  // 3 => [3]
  auto reduced_num = Prod(cpu_scope, Gather(scope, input_shape, reduced), 0);

  // 2 * 2 => [2]
  auto other_num = Prod(cpu_scope, Gather(scope, input_shape, other), 0);

  // [
  //   [
  //     [ 3., 4.],
  //     [ 3., 5.]
  //   ],
  //   [
  //     [ 5., 6.],
  //     [ 0., 6.]
  //   ],
  //   [
  //     [ 7., 8.],
  //     [ 5., 6.]
  //   ]
  // ]
  auto permuted = Transpose(scope, op.input(0), perm);

  // [3, 2, 2]
  auto permuted_shape = Shape(scope, permuted);

  // [
  //   [ 3., 4., 3., 5.],
  //   [ 5., 6., 0., 6.],
  //   [ 7., 8., 5., 6.]
  // ]
  auto reshaped = Reshape(
      scope, permuted,
      Stack(scope, std::initializer_list<Input>{reduced_num, other_num}));

  // [
  //   [ 1., 1., 1., 1.],
  //   [ 3., 4., 3., 5.],
  //   [ 15., 24., 0., 30.]
  // ]
  auto left = Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true));

  // [
  //   [ 35., 48., 0., 36.],
  //   [ 7., 8., 5., 6.],
  //   [ 1., 1., 1., 1.]
  // ]
  auto right =
      Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true).Reverse(true));

  // left * right =
  // [
  //   [ 35., 48., 0., 36.],
  //   [ 21., 32., 15., 30.],
  //   [ 15., 24., 0., 30.]
  // ]
  // y =
  // [
  //   [
  //     [ 35., 48.],
  //     [ 0., 36.]
  //   ],
  //   [
  //     [ 21., 32.],
  //     [ 15., 30.]
  //   ],
  //   [
  //     [ 15., 24.],
  //     [ 0., 30.]
  //   ]
  // ]
  auto y = Reshape(scope, Mul(scope, left, right), permuted_shape);

  // out =
  // [
  //   [
  //     [ 35., 48.],
  //     [ 21., 32.],
  //     [ 15., 24.]
  //   ],
  //   [
  //     [ 0., 36.],
  //     [ 15., 30.],
  //     [ 0., 30.]
  //   ]
  // ]
  auto out = Mul(scope, grad_tiled,
                 Transpose(scope, y, InvertPermutation(scope, perm)));

  grad_outputs->push_back(Reshape(scope, out, input_shape));

  // stop propagation along reduction_indices
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Prod", ProdGrad);

Status SegmentSumGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  // The SegmentSum operation sums segments of the Tensor that have the same
  // index in the segment_ids parameter.
  // e.g. z = [2, 3, 4, 5] with segment_ids [0, 0, 0, 1]
  // will produce [2 + 3 + 4, 5] = [9, 5].
  // The incoming gradient, e.g. [x1, x2], has the same shape as the output of
  // the SegmentSum operation. The gradient step simply gathers it with
  // segment_ids to broadcast it back to z's shape:
  // dz = [x1, x1, x1, x2]
  grad_outputs->push_back(Gather(scope, grad_inputs[0], op.input(1)));

  // stop propagation along segment_ids
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SegmentSum", SegmentSumGrad);

// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
                        const Output& x0, const bool adj_x0, const Output& x1,
                        const bool adj_x1, const Output& y0, const bool adj_y0,
                        const Output& y1, const bool adj_y1,
                        std::vector<Output>* grad_outputs) {
  if (is_batch == false) {
    auto dx =
        MatMul(scope, x0, x1, MatMul::TransposeA(adj_x0).TransposeB(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        MatMul(scope, y0, y1, MatMul::TransposeA(adj_y0).TransposeB(adj_y1));
    grad_outputs->push_back(dy);
  } else {
    auto dx =
        BatchMatMul(scope, x0, x1, BatchMatMul::AdjX(adj_x0).AdjY(adj_x1));
    grad_outputs->push_back(dx);
    auto dy =
        BatchMatMul(scope, y0, y1, BatchMatMul::AdjX(adj_y0).AdjY(adj_y1));
    grad_outputs->push_back(dy);
  }
  return scope.status();
}

// MatMulGradCommon reads and checks the node's attr state, and determines the
// proper MatMul products for the gradients based on the input matrix
// transposition combinations.
Status MatMulGradCommon(const Scope& scope, const Operation& op,
                        const bool is_batch,
                        const std::vector<Output>& grad_inputs,
                        const string& attr_adj_x, const string& attr_adj_y,
                        std::vector<Output>* grad_outputs) {
  auto a = op.input(0);
  auto b = op.input(1);
  // Use conjugate of the inputs for MatMul
  if (is_batch == false) {
    a = ConjugateHelper(scope, a);
    b = ConjugateHelper(scope, b);
  }
  auto product = op.output(0);

  bool ta;
  bool tb;
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_x, &ta));
  TF_RETURN_IF_ERROR(GetNodeAttr(product.node()->attrs(), attr_adj_y, &tb));

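  // For C = op(A) * op(B), each of dA and dB is a single (Batch)MatMul of the
  // incoming gradient with the other, possibly transposed, input; the four
  // branches below cover the transpose_a/transpose_b (adj_x/adj_y) cases.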
  if (!ta && !tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, true, a,
                            true, grad_inputs[0], false, grad_outputs);
  } else if (!ta && tb) {
    return MatMulGradHelper(scope, is_batch, grad_inputs[0], false, b, false,
                            grad_inputs[0], true, a, false, grad_outputs);
  } else if (ta && !tb) {
    return MatMulGradHelper(scope, is_batch, b, false, grad_inputs[0], true, a,
                            false, grad_inputs[0], false, grad_outputs);
  }
  return MatMulGradHelper(scope, is_batch, b, true, grad_inputs[0], true,
                          grad_inputs[0], true, a, true, grad_outputs);
}

Status MatMulGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
                          "transpose_b", grad_outputs);
}
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);

Status BatchMatMulGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  return MatMulGradCommon(scope, op, true, grad_inputs, "adj_x", "adj_y",
                          grad_outputs);
}
REGISTER_GRADIENT_OP("BatchMatMul", BatchMatMulGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow