/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace ops {
namespace {

REGISTER_NO_GRADIENT_OP("Const");
REGISTER_NO_GRADIENT_OP("StopGradient");
REGISTER_NO_GRADIENT_OP("ConcatOffset");
REGISTER_NO_GRADIENT_OP("EditDistance");
REGISTER_NO_GRADIENT_OP("ZerosLike");
REGISTER_NO_GRADIENT_OP("InvertPermutation");
REGISTER_NO_GRADIENT_OP("Shape");
REGISTER_NO_GRADIENT_OP("ShapeN");
REGISTER_NO_GRADIENT_OP("Rank");
REGISTER_NO_GRADIENT_OP("Size");
REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs");
REGISTER_NO_GRADIENT_OP("OneHot");

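// Pack stacks N tensors along `axis` into a new dimension; its gradient
// unstacks the incoming gradient along the same axis, yielding one gradient
// per packed input.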
Status PackGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  int N;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N));
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));

  grad_outputs->reserve(N);
  auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis));
  for (const Output& o : grad_op.output) {
    grad_outputs->emplace_back(o);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pack", PackGrad);

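// Unpack is the inverse of Pack: the gradient restacks the per-output
// gradients along the original axis.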
Status UnpackGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
  grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Unpack", UnpackGrad);

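// Identity-like ops (Identity, RefIdentity, QuantizeAndDequantize) pass the
// incoming gradient through unchanged.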
Status IdentityGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Identity", IdentityGrad);

Status RefIdentityGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad);

Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
                                 const std::vector<Output>& grad_inputs,
                                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);

Status QuantizeAndDequantizeV4GradHelper(const Scope& scope,
                                         const Operation& op,
                                         const std::vector<Output>& grad_inputs,
                                         std::vector<Output>* grad_outputs) {
  Input input = Shape(scope, op.input(0));
  Input input_min = op.input(1);
  Input input_max = op.input(2);
  int64_t axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
  auto qdq_v4_grad = QuantizeAndDequantizeV4Grad(
      scope, grad_inputs[0], input, input_min, input_max,
      QuantizeAndDequantizeV4Grad::Axis(axis));
  grad_outputs->push_back(qdq_v4_grad.input_backprop);
  grad_outputs->push_back(qdq_v4_grad.input_min_backprop);
  grad_outputs->push_back(qdq_v4_grad.input_max_backprop);
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4",
                     QuantizeAndDequantizeV4GradHelper);

Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV3", QuantizeAndDequantizeV3Grad);

Status SplitGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Split", SplitGrad);

Status SplitVGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  if (op.num_inputs() < 3) {
    return errors::InvalidArgument("SplitV requires 3 arguments");
  }
  grad_outputs->push_back(Concat(scope, grad_inputs, op.input(2)));
  for (int i = 0; i < op.num_inputs() - 1; ++i) {
    grad_outputs->push_back(NoGradient());
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("SplitV", SplitVGrad);

Status FillGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = fill(fill_shape, x)
  // No gradient returned for the fill_shape argument.
  grad_outputs->push_back(NoGradient());
  // The gradient for x (which must be a scalar) is just the sum of
  // all the gradients from the shape it fills.
  // We use ReduceSum to implement this, which needs an argument providing
  // the indices of all the dimensions of the incoming gradient.
  // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))])
  auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]),
                        Const(scope, 1));
  grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims));
  return scope.status();
}
REGISTER_GRADIENT_OP("Fill", FillGrad);

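// Diag and DiagPart are adjoint to each other: the gradient of Diag extracts
// the diagonal of dy (DiagPart), and the gradient of DiagPart embeds dy back
// onto a diagonal (Diag). The same relationship is used for MatrixDiag below.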
Status DiagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(DiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Diag", DiagGrad);

Status DiagPartGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Diag(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad);

Status MatrixDiagGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad);

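// MatrixBandPart zeroes entries outside a band around the diagonal, so its
// gradient applies the same band mask to dy; num_lower and num_upper receive
// no gradient.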
Status MatrixBandPartGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  auto num_lower = op.input(1);
  auto num_upper = op.input(2);
  grad_outputs->push_back(
      MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad);

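// GatherNd reads slices of `ref` at `indices`; the gradient scatters dy back
// into a zero tensor of ref's shape with ScatterNd. The indices receive no
// gradient.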
Status GatherNdGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  auto ref = op.input(0);
  auto ref_shape = Shape(scope, ref);
  auto indices = op.input(1);
  grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad);

Status CheckNumericsGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string message;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message));
  string err_msg = strings::StrCat(
      "Not a number (NaN) or infinity (Inf) values detected in gradient. ",
      message);
  grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg));
  return scope.status();
}
REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad);

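// Reshape, ExpandDims and Squeeze only rearrange shape metadata, so their
// gradient is dy reshaped back to the original input shape.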
Status ReshapeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Reshape", ReshapeGrad);

Status ExpandDimsGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad);

Status SqueezeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  return scope.status();
}
REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad);

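// The gradient of Transpose is dy transposed with the inverse of the original
// permutation.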
Status TransposeGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto inverted_perm = InvertPermutation(scope, op.input(1));
  grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Transpose", TransposeGrad);

Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  auto seq_lengths = op.input(1);
  int batch_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim));
  int seq_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim));
  grad_outputs->push_back(
      ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
                      ReverseSequence::BatchDim(batch_dim)));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad);

Status ReverseGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto reverse_dims = op.input(1);
  grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad);

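// ScatterNd writes `updates` into a zero tensor at `indices`; the gradient
// with respect to updates gathers dy at those indices, while indices and
// shape receive no gradient.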
Status ScatterNdGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto indices = op.input(0);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad);

Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  auto indices = op.input(1);
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad);

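// Pad copies the input into the interior of a larger padded tensor, so the
// gradient is the slice of dy corresponding to the original, unpadded input.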
template <bool IsPadV2>
Status PadGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = op.input(0);
  auto a = op.input(1);  // [Rank(x), 2]
  // Takes a slice of a. The 1st column. [Rank(x), 1].
  auto size = Stack(scope, {Rank(scope, x), 1});
  auto pad_before = Slice(scope, a, {0, 0}, size);
  // Make it a 1-D tensor.
  auto begin = Reshape(scope, pad_before, {-1});
  grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
  grad_outputs->push_back(NoGradient());
  // PadV2 adds a "constant_values" input.
  if (IsPadV2) {
    grad_outputs->push_back(NoGradient());
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pad", PadGrad<false>);
REGISTER_GRADIENT_OP("PadV2", PadGrad<true>);

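// SpaceToBatch/BatchToSpace and SpaceToDepth/DepthToSpace are mutually
// inverse rearrangements of elements, so each gradient simply applies the
// opposite op to dy with the same block size.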
Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(
      BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad);

Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad);

Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(
      SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad);

Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad);

Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad);

Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad);

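// MirrorPad reflects the input into the padded regions, so gradient
// contributions from the padding must be accumulated back onto the entries
// they mirror; the internal MirrorPadGrad op performs that folding.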
Status MirrorPadGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
  grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
      scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad);

// TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4.
Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
  grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);

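// StridedSlice's gradient scatters dy back into a zero tensor of the input's
// shape, using the StridedSliceGrad op with the same begin/end/strides and
// mask attributes as the forward op.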
Status StridedSliceGradHelper(const Scope& scope, const Operation& op,
                              const std::vector<Output>& grad_inputs,
                              std::vector<Output>* grad_outputs) {
  Input x = Shape(scope, op.input(0));
  Input begin = op.input(1);
  Input end = op.input(2);
  Input strides = op.input(3);
  int64_t begin_mask;
  int64_t end_mask;
  int64_t ellipsis_mask;
  int64_t new_axis_mask;
  int64_t shrink_axis_mask;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "end_mask", &end_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "shrink_axis_mask", &shrink_axis_mask));
  grad_outputs->push_back(
      StridedSliceGrad(scope, x, begin, end, strides, grad_inputs[0],
                       StridedSliceGrad::BeginMask(begin_mask)
                           .EndMask(end_mask)
                           .EllipsisMask(ellipsis_mask)
                           .NewAxisMask(new_axis_mask)
                           .ShrinkAxisMask(shrink_axis_mask)));
  // No gradients returned for begin, end and strides
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper);

Status SliceGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Propagate the incoming gradient along all the selected values,
  // and zero everywhere else. Use the Pad operator for this.
  //
  // First create an Nx2 padding where N is the number of input
  // dimensions. The first column is the number of prepended zeros
  // for each dimension, and the second column is the number of
  // appended zeros.
  //
  // The first column is just the begin vector.
  // The second column is the shape of the input element-wise
  // subtracted by begin+size

  // Running example:
  // input.shape = [3, 5, 3]
  // begin = [1, 2, 1], size = [1, 3, 2]
  Input input = op.input(0);
  Input begin = op.input(1);
  // input_rank = 3
  auto input_rank = Rank(scope, input);
  // slice_size = [1, 3, 2]
  auto slice_size = Shape(scope, op.output(0));
  // padding_shape = [3, 1]
  auto padding_shape = Stack(scope, {input_rank, 1});
  // before_padding = [[1]
  //                   [2]
  //                   [1]]
  Input before_padding = Reshape(scope, begin, padding_shape);
  // after_padding_sizes = shape(input) - slice_size - begin
  //                     = [3, 5, 3] - [1, 3, 2] - [1, 2, 1]
  //                     = [1, 0, 0]
  auto after_padding_sizes =
      Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin);
  // after_padding = [[1]
  //                  [0]
  //                  [0]]
  Input after_padding = Reshape(scope, after_padding_sizes, padding_shape);
  // paddings = [[1 1]
  //             [2 0]
  //             [1 0]]
  auto paddings =
      Concat(scope, {before_padding, after_padding}, Const(scope, 1));
  grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings));
  // Nothing propagated for "begin" and "size" inputs
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Slice", SliceGrad);

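// Concat's gradient slices dy back into pieces matching the shape of each
// concatenated input. ConcatOffset, built directly with NodeBuilder, computes
// the starting offset of each piece along the concat axis.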
Status ConcatGradHelper(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs,
                        int start_value_index, int end_value_index,
                        int dim_index) {
  if (end_value_index >= op.num_inputs()) {
    return errors::Internal("Invalid input index");
  }
  std::vector<Output> inputs;
  for (int i = start_value_index; i < end_value_index; ++i) {
    inputs.push_back(op.input(i));
  }

  auto shapes = ShapeN(scope, inputs);
  const auto unique_name = scope.GetUniqueNameForOp("ConcatOffset");
  auto builder =
      ::tensorflow::NodeBuilder(unique_name, "ConcatOffset")
          .Input(::tensorflow::ops::AsNodeOut(scope, op.input(dim_index)))
          .Input(::tensorflow::ops::AsNodeOutList(scope, shapes.output));
  scope.UpdateBuilder(&builder);
  ::tensorflow::Node* concat_offset_node;
  scope.UpdateStatus(builder.Finalize(scope.graph(), &concat_offset_node));
  scope.UpdateStatus(scope.DoShapeInference(concat_offset_node));
  if (concat_offset_node->num_outputs() != inputs.size()) {
    return errors::Internal("ConcatOffset has invalid output count");
  }
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("Concat grad should have 1 input");
  }

  // For each dx[i], we take a slice of dy. The offset and size of the
  // slice is given by offset[i] and shape[i].
  const Output& dy = grad_inputs[0];
  for (int i = 0; i < inputs.size(); ++i) {
    grad_outputs->push_back(
        Slice(scope, dy, Output(concat_offset_node, i), shapes.output[i]));
  }

  // Insert a NoGradient for the axis.
  grad_outputs->insert(grad_outputs->begin() + dim_index, NoGradient());
  return scope.status();
}

Status ConcatV2Grad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  return ConcatGradHelper(scope, op, grad_inputs, grad_outputs,
                          /*start_value_index=*/0,
                          /*end_value_index=*/op.num_inputs() - 1,
                          /*dim_index=*/op.num_inputs() - 1);
}

REGISTER_GRADIENT_OP("ConcatV2", ConcatV2Grad);

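// BroadcastTo replicates x along broadcast dimensions, so the gradient sums
// dy over those dimensions (computed by BroadcastGradientArgs) and reshapes
// the result back to x's shape.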
Status BroadcastToGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("BroadcastTo grad should have 1 grad input");
  }
  if (op.num_inputs() != 2) {
    return errors::InvalidArgument("BroadcastTo requires 2 inputs");
  }

  auto x_shape = Shape(scope, op.input(0));
  auto args = internal::BroadcastGradientArgs(scope, x_shape, op.input(1));
  auto sum_gx = Sum(scope, grad_inputs[0], args.r0);
  grad_outputs->push_back(Reshape(scope, sum_gx, x_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}

REGISTER_GRADIENT_OP("BroadcastTo", BroadcastToGrad);

Status TileGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  if (op.num_inputs() != 2) {
    return errors::InvalidArgument("Tile requires 2 inputs");
  }
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("Tile grad requires 1 grad input");
  }

  Shape::Attrs shape_attrs;
  shape_attrs.out_type_ = op.input_type(1);
  auto input_shape = Shape(scope, op.input(0), shape_attrs);
  // We interleave multiples and input_shape to get split_shape,
  // reshape grad to split_shape, and reduce along all even
  // dimensions (the tiled dimensions) to get the result
  // with shape input_shape.  For example
  //   input_shape = [20, 30, 40]
  //   multiples = [2, 3, 4]
  //   split_shape = [2, 20, 3, 30, 4, 40]
  //   axes = [0, 2, 4]
  auto stack = Stack(scope, {op.input(1), input_shape.output});
  auto perm = Range(scope, Sub(scope, Rank(scope, stack), 1), -1, -1);
  auto split_shape = Reshape(scope, Transpose(scope, stack, perm), {-1});
  auto axes = Range(scope, Const(scope, 0), Size(scope, split_shape.output), 2);
  auto input_grad = ReduceSum(
      scope, Reshape(scope, grad_inputs[0], split_shape.output), axes.output);
  grad_outputs->push_back(input_grad.output);
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Tile", TileGrad);

// Creates a constant of the provided d_type.
Output ConstHelper(const Scope& scope, int value, DataType d_type) {
  return Cast(scope, Const(scope, value), d_type);
}

// Adds the batch offsets to the given indices and returns the results.
Output GetBatchIndices(const Scope& scope, const Output& params_shape,
                       const Output& indices, int batch_dims) {
  Output batch_indices = indices;
  auto indices_ndims = Rank(scope, indices);
  auto casted_params_shape = Cast(scope, params_shape, indices.type());
  Output accum_dim_value = ConstHelper(scope, 1, indices.type());
  for (int dim = batch_dims; dim > 0; dim--) {
    Output dim_value = Slice(scope, casted_params_shape, {dim - 1}, {1});
    accum_dim_value = Multiply(scope, accum_dim_value,
                               Slice(scope, casted_params_shape, {dim}, {1}));
    auto start = ConstHelper(scope, 0, indices.type());
    auto step = ConstHelper(scope, 1, indices.type());
    Output dim_indices = Range(scope, start, Squeeze(scope, dim_value), step);
    dim_indices = Multiply(scope, dim_indices, accum_dim_value);
    auto one = Cast(scope, Const(scope, {1}), indices.type());
    auto dim_shape = Concat(
        scope,
        {Output(Tile(scope, one, Const(scope, {dim - 1}))), dim_value,
         Output(Tile(scope, one,
                     ExpandDims(scope, Sub(scope, indices_ndims, dim), 0)))},
        /*axis=*/0);
    batch_indices =
        Add(scope, batch_indices, Reshape(scope, dim_indices, dim_shape));
  }

  return batch_indices;
}

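// Scatters `values` back into a tensor of shape `params_shape` along the
// first non-batch dimension using UnsortedSegmentSum. When batch_dims > 0,
// the batch dimensions are flattened and batch offsets are added to the
// indices first; the result is then reshaped back to params_shape.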
Output BatchGatherGrad(const Scope& scope, Output params_shape, Output values,
                       Output indices, int batch_dims, Output gather_dim_size) {
  // Axis is the first non-batch dimension.
  auto indices_size = ExpandDims(scope, Size(scope, indices), 0);
  Output outer_shape, flat_values_shape;
  if (batch_dims != 0) {
    auto values_shape = Shape(scope, values);
    // Add the batch offsets to indices and flatten the batch dimensions.
    outer_shape = Slice(scope, values_shape, {0}, {batch_dims});
    auto inner_shape =
        Slice(scope, Slice(scope, values_shape, {batch_dims}, {-1}), {1}, {-1});
    auto batch_size = Prod(scope, outer_shape, /*axis=*/0);
    flat_values_shape = Concat(scope, {{-1}, inner_shape}, /*axis=*/0);
    gather_dim_size = Multiply(scope, gather_dim_size, batch_size);
    indices = GetBatchIndices(scope, params_shape, indices, batch_dims);
    values = Reshape(scope, values, flat_values_shape);
  }

  indices = Reshape(scope, indices, indices_size);
  Output params_grad =
      UnsortedSegmentSum(scope, values, indices, gather_dim_size);

  if (batch_dims != 0) {
    // Put back the batch dimensions.
    params_grad = Reshape(scope, params_grad, params_shape);
  }
  return params_grad;
}

Status GatherV2Grad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  if (op.num_inputs() != 3) {
    return errors::InvalidArgument("Gather requires 3 inputs");
  }
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("Gather grad requires 1 grad input");
  }

  // params can be large, so colocate the shape calculation with it.
  // params can be very large for sparse models; array_ops.shape raises an
  // exception on the Windows platform when any dimension is larger than
  // int32. params_shape is not used in optimizer apply_sparse gradients,
  // so it's fine to convert it back to int32 regardless of truncation.
  auto params = op.input(0);
  auto colocate_scope = scope.ColocateWith(params);
  Shape::Attrs shape_attrs;
  shape_attrs.out_type_ = DT_INT64;
  auto params_shape64 = Shape(colocate_scope, params, shape_attrs);
  Output params_shape = Cast(colocate_scope, params_shape64, DT_INT32);

  auto indices = op.input(1);
  auto indices_size = ExpandDims(scope, Size(scope, indices), 0);
  auto axis = op.input(2);
  auto axis_expand = ExpandDims(scope, axis, 0);

  int batch_dims;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "batch_dims", &batch_dims));
  if (batch_dims < 0) {
    // TODO(bdodson): Figure out if we can find the param rank here, like the
    // python implementation does.
    return errors::InvalidArgument(
        "C++ GatherV2 gradient does not support negative batch_dims.");
  }

  // Handle axis by transposing the axis dimension to be the first non-batch
  // dimension, compute the gradient and transpose the result back.
  auto outer_shape = Slice(scope, params_shape, {0}, axis_expand);
  auto inner_shape =
      Slice(scope, Slice(scope, params_shape, axis_expand, {-1}), {1}, {-1});
  auto values_shape = Concat(scope, {outer_shape, {-1}, inner_shape}, 0);
  auto values_dims = Size(scope, values_shape);
  auto axis_dims = Size(scope, outer_shape);

  Output outer_batches_indices = Range(scope, 0, batch_dims, /*delta=*/1);
  Output batch_axis_indices = Range(scope, batch_dims, axis_dims, /*delta=*/1);
  Output inner_axes_indices =
      Range(scope, Add(scope, axis_dims, 1), values_dims, /*delta=*/1);
  Output axis_dims_expand = ExpandDims(scope, axis_dims, 0);

  auto values = Reshape(scope, grad_inputs[0], values_shape);

  // Move values[axis] up to values[batch_dims]
  Output transpose_dims = Concat(scope,
                                 {outer_batches_indices, axis_dims_expand,
                                  batch_axis_indices, inner_axes_indices},
                                 0);
  auto values_transpose = Transpose(scope, values, transpose_dims);
  Output gather_dim_size =
      Squeeze(scope, Slice(scope, params_shape, axis_expand, {1}));
  params_shape = Gather(scope, params_shape, transpose_dims);

  auto params_grad = BatchGatherGrad(scope, params_shape, values_transpose,
                                     indices, batch_dims, gather_dim_size);

  // Inverts the above transpose by moving dimension batch_dims back to its
  // original position.
  Output invert_transpose_dims = Concat(scope,
                                        {outer_batches_indices,
                                         Add(scope, batch_axis_indices, 1),
                                         {batch_dims},
                                         inner_axes_indices},
                                        0);

  params_grad = Transpose(scope, params_grad, invert_transpose_dims);

  grad_outputs->push_back(params_grad);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}

REGISTER_GRADIENT_OP("GatherV2", GatherV2Grad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow