/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/strings/strcat.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

namespace tensorflow {
namespace ops {
namespace {

REGISTER_NO_GRADIENT_OP("Const");
REGISTER_NO_GRADIENT_OP("StopGradient");
REGISTER_NO_GRADIENT_OP("ConcatOffset");
REGISTER_NO_GRADIENT_OP("EditDistance");
REGISTER_NO_GRADIENT_OP("ZerosLike");
REGISTER_NO_GRADIENT_OP("InvertPermutation");
REGISTER_NO_GRADIENT_OP("Shape");
REGISTER_NO_GRADIENT_OP("ShapeN");
REGISTER_NO_GRADIENT_OP("Rank");
REGISTER_NO_GRADIENT_OP("Size");
REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs");
REGISTER_NO_GRADIENT_OP("OneHot");

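// Pack (a.k.a. Stack) joins N tensors along `axis`; its gradient simply
// unstacks the incoming gradient along the same axis, yielding one slice
// per packed input.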
Status PackGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  int N;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N));
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));

  grad_outputs->reserve(N);
  auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis));
  for (const Output& o : grad_op.output) {
    grad_outputs->emplace_back(o);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pack", PackGrad);

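// Unpack is the inverse of Pack: the per-output gradients are re-stacked
// along the same axis to recover a gradient with the input's shape.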
Status UnpackGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
  grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Unpack", UnpackGrad);

Status IdentityGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Identity", IdentityGrad);

Status RefIdentityGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad);

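// The quantize-and-dequantize ops are treated as identity for gradient
// purposes: the incoming gradient passes straight through to the input, and
// the range/num_bits inputs of the V2/V3 variants receive no gradient.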
Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
                                 const std::vector<Output>& grad_inputs,
                                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);

Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad);

Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV3", QuantizeAndDequantizeV3Grad);

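// Split partitions its value input (input 1) along the dimension given by
// input 0; the gradient concatenates the per-output gradients back together
// along that same dimension. No gradient flows to the split dimension.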
Status SplitGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Split", SplitGrad);

Status FillGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = fill(fill_shape, x)
  // No gradient returned for the fill_shape argument.
  grad_outputs->push_back(NoGradient());
  // The gradient for x (which must be a scalar) is just the sum of
  // all the gradients from the shape it fills.
  // We use ReduceSum to implement this, which needs an argument providing
  // the indices of all the dimensions of the incoming gradient.
  // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))])
  auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]),
                        Const(scope, 1));
  grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims));
  return scope.status();
}
REGISTER_GRADIENT_OP("Fill", FillGrad);

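// Diag builds a tensor whose diagonal is the input, and DiagPart extracts
// that diagonal; each is the other's inverse, so the gradient of one is the
// other applied to the incoming gradient (likewise for MatrixDiag).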
Status DiagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(DiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Diag", DiagGrad);

Status DiagPartGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Diag(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad);

Status MatrixDiagGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad);

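// MatrixBandPart zeroes everything outside a band of the matrix; its gradient
// applies the same band mask to the incoming gradient. The num_lower and
// num_upper inputs receive no gradient.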
Status MatrixBandPartGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  auto num_lower = op.input(1);
  auto num_upper = op.input(2);
  grad_outputs->push_back(
      MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad);

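// GatherNd reads slices of its params input at `indices`; the gradient
// scatters the incoming gradient back to those positions in a zero tensor
// shaped like the original params. The indices receive no gradient.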
Status GatherNdGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  auto ref = op.input(0);
  auto ref_shape = Shape(scope, ref);
  auto indices = op.input(1);
  grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad);

Status CheckNumericsGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string message;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message));
  string err_msg = strings::StrCat(
      "Not a number (NaN) or infinity (Inf) values detected in gradient. ",
      message);
  grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg));
  return scope.status();
}
REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad);

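// Reshape, ExpandDims and Squeeze only rearrange the layout of their input,
// so the gradient is the incoming gradient reshaped back to the original
// input shape. The shape/dim inputs of Reshape and ExpandDims get no gradient.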
Status ReshapeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Reshape", ReshapeGrad);

Status ExpandDimsGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad);

Status SqueezeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  return scope.status();
}
REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad);

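// Transpose permutes dimensions according to `perm` (input 1); the gradient
// is the incoming gradient transposed with the inverted permutation.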
Status TransposeGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto inverted_perm = InvertPermutation(scope, op.input(1));
  grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Transpose", TransposeGrad);

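// Reversing is its own inverse, so for ReverseSequence and ReverseV2 the
// gradient is the same reversal (same sequence lengths / dims) applied to
// the incoming gradient.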
Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  auto seq_lengths = op.input(1);
  int batch_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim));
  int seq_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim));
  grad_outputs->push_back(
      ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
                      ReverseSequence::BatchDim(batch_dim)));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad);

Status ReverseGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto reverse_dims = op.input(1);
  grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad);

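// ScatterNd writes `updates` into a zero tensor at `indices`; the gradient
// with respect to the updates is therefore a GatherNd of the incoming
// gradient at the same indices. The indices and shape get no gradient.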
Status ScatterNdGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto indices = op.input(0);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad);

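// ScatterNdNonAliasingAdd adds `updates` into a copy of `input` at `indices`:
// the input's gradient passes through unchanged, while the updates' gradient
// is gathered from the incoming gradient at those indices.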
Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  auto indices = op.input(1);
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad);

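// The gradient of Pad is a Slice of the incoming gradient that strips the
// padding back off: the slice begins at the "pad before" amounts (the first
// column of the paddings input) and has the shape of the original input.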
template <bool IsPadV2>
Status PadGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = op.input(0);
  auto a = op.input(1);  // [Rank(x), 2]
  // Takes a slice of a. The 1st column. [Rank(x), 1].
  auto size = Stack(scope, {Rank(scope, x), 1});
  auto pad_before = Slice(scope, a, {0, 0}, size);
  // Make it a 1-D tensor.
  auto begin = Reshape(scope, pad_before, {-1});
  grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
  grad_outputs->push_back(NoGradient());
  // PadV2 adds a "constant_values" input.
  if (IsPadV2) {
    grad_outputs->push_back(NoGradient());
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pad", PadGrad<false>);
REGISTER_GRADIENT_OP("PadV2", PadGrad<true>);

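// The SpaceToBatch(ND)/BatchToSpace(ND) and SpaceToDepth/DepthToSpace pairs
// are lossless rearrangements of elements, so each op's gradient is its
// inverse applied to the incoming gradient with the same block size and
// paddings/crops. The paddings/block-shape inputs receive no gradient.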
Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(
      BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad);

Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad);

Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(
      SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad);

Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad);

Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad);

Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad);

Status MirrorPadGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
  grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
      scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad);

// TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4.
Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
  grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);

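// StridedSliceGrad scatters the incoming gradient back into a zero tensor of
// the original input shape, reusing the same begin/end/strides inputs and
// mask attributes as the forward slice.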
Status StridedSliceGradHelper(const Scope& scope, const Operation& op,
                              const std::vector<Output>& grad_inputs,
                              std::vector<Output>* grad_outputs) {
  Input x = Shape(scope, op.input(0));
  Input begin = op.input(1);
  Input end = op.input(2);
  Input strides = op.input(3);
  int64 begin_mask;
  int64 end_mask;
  int64 ellipsis_mask;
  int64 new_axis_mask;
  int64 shrink_axis_mask;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "end_mask", &end_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "shrink_axis_mask", &shrink_axis_mask));
  grad_outputs->push_back(
      StridedSliceGrad(scope, x, begin, end, strides, grad_inputs[0],
                       StridedSliceGrad::BeginMask(begin_mask)
                           .EndMask(end_mask)
                           .EllipsisMask(ellipsis_mask)
                           .NewAxisMask(new_axis_mask)
                           .ShrinkAxisMask(shrink_axis_mask)));
  // No gradients returned for begin, end and strides.
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper);

Status SliceGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Propagate the incoming gradient along all the selected values,
  // and zero everywhere else. Use the Pad operator for this.
  //
  // First create an Nx2 padding where N is the number of input
  // dimensions. The first column is the number of prepended zeros
  // for each dimension, and the second column is the number of
  // appended zeros.
  //
  // The first column is just the begin vector.
  // The second column is the input shape minus (begin + size), element-wise.

  // Running example:
  // input.shape = [3, 5, 3]
  // begin = [1, 2, 1], size = [1, 3, 2]
  Input input = op.input(0);
  Input begin = op.input(1);
  // input_rank = 3
  auto input_rank = Rank(scope, input);
  // slice_size = [1, 3, 2]
  auto slice_size = Shape(scope, op.output(0));
  // padding_shape = [3, 1]
  auto padding_shape = Stack(scope, {input_rank, 1});
  // before_padding = [[1]
  //                   [2]
  //                   [1]]
  Input before_padding = Reshape(scope, begin, padding_shape);
  // after_padding_sizes = shape(input) - slice_size - begin
  //                     = [3, 5, 3] - [1, 3, 2] - [1, 2, 1]
  //                     = [1, 0, 0]
  auto after_padding_sizes =
      Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin);
  // after_padding = [[1]
  //                  [0]
  //                  [0]]
  Input after_padding = Reshape(scope, after_padding_sizes, padding_shape);
  // paddings = [[1 1]
  //             [2 0]
  //             [1 0]]
  auto paddings =
      Concat(scope, {before_padding, after_padding}, Const(scope, 1));
  grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings));
  // Nothing propagated for the "begin" and "size" inputs.
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Slice", SliceGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow