/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/strings/strcat.h"

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"

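// This file registers symbolic gradient functions for array ops with the C++
// GradOpRegistry; they are looked up when gradient subgraphs are built through
// the C++ gradients API (e.g. AddSymbolicGradients in
// tensorflow/cc/framework/gradients.h).
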
namespace tensorflow {
namespace ops {
namespace {

REGISTER_NO_GRADIENT_OP("Const");
REGISTER_NO_GRADIENT_OP("StopGradient");
REGISTER_NO_GRADIENT_OP("ConcatOffset");
REGISTER_NO_GRADIENT_OP("EditDistance");
REGISTER_NO_GRADIENT_OP("ZerosLike");
REGISTER_NO_GRADIENT_OP("InvertPermutation");
REGISTER_NO_GRADIENT_OP("Shape");
REGISTER_NO_GRADIENT_OP("ShapeN");
REGISTER_NO_GRADIENT_OP("Rank");
REGISTER_NO_GRADIENT_OP("Size");
REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs");
REGISTER_NO_GRADIENT_OP("OneHot");

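// Pack (Stack) joins N tensors along a new `axis` dimension, so its gradient
// is simply Unpack (Unstack) of the incoming gradient along the same axis,
// one slice per original input.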
Status PackGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  int N;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N));
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));

  grad_outputs->reserve(N);
  auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis));
  for (const Output& o : grad_op.output) {
    grad_outputs->emplace_back(o);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pack", PackGrad);

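// Unpack (Unstack) splits a tensor into slices along `axis`; its gradient
// restacks the per-slice gradients along the same axis.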
Status UnpackGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
  grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Unpack", UnpackGrad);

Status IdentityGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Identity", IdentityGrad);

Status RefIdentityGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad);

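// The QuantizeAndDequantize ops are differentiated as straight-through: the
// incoming gradient passes through unchanged to the value input, while the
// additional inputs of V2/V3 (quantization range, and num_bits for V3)
// receive no gradient.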
Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
                                 const std::vector<Output>& grad_inputs,
                                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);

Status QuantizeAndDequantizeV2Grad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV2", QuantizeAndDequantizeV2Grad);

Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV3", QuantizeAndDequantizeV3Grad);

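// Split's first input is the dimension to split along (no gradient); the
// gradient with respect to the value input is the concatenation of the output
// gradients along that same dimension.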
Status SplitGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Split", SplitGrad);

Status FillGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = fill(fill_shape, x)
  // No gradient returned for the fill_shape argument.
  grad_outputs->push_back(NoGradient());
  // The gradient for x (which must be a scalar) is just the sum of
  // all the gradients from the shape it fills.
  // We use ReduceSum to implement this, which needs an argument providing
  // the indices of all the dimensions of the incoming gradient.
  // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))])
  auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]),
                        Const(scope, 1));
  grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims));
  return scope.status();
}
REGISTER_GRADIENT_OP("Fill", FillGrad);

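// Diag and DiagPart are adjoint linear maps, so the gradient of each is the
// other applied to the incoming gradient (similarly, MatrixDiag's gradient
// below is MatrixDiagPart).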
Status DiagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(DiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Diag", DiagGrad);

Status DiagPartGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Diag(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad);

Status MatrixDiagGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad);

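// MatrixBandPart zeroes everything outside a band; its gradient keeps the same
// band of the incoming gradient, and the num_lower / num_upper inputs receive
// no gradient.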
Status MatrixBandPartGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  auto num_lower = op.input(1);
  auto num_upper = op.input(2);
  grad_outputs->push_back(
      MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad);

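// GatherNd reads slices of `params` at `indices`; the gradient scatters the
// incoming gradient back into a zero tensor of the params' shape via
// ScatterNd. The integer `indices` input gets no gradient.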
Status GatherNdGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  auto ref = op.input(0);
  auto ref_shape = Shape(scope, ref);
  auto indices = op.input(1);
  grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad);

Status CheckNumericsGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string message;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message));
  string err_msg = strings::StrCat(
      "Not a number (NaN) or infinity (Inf) values detected in gradient. ",
      message);
  grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg));
  return scope.status();
}
REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad);

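// Reshape, ExpandDims, and Squeeze only rearrange elements, so their gradients
// reshape the incoming gradient back to the original input shape.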
Status ReshapeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Reshape", ReshapeGrad);

Status ExpandDimsGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad);

Status SqueezeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  return scope.status();
}
REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad);

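// The gradient of Transpose is another Transpose, using the inverted
// permutation; the permutation input itself gets no gradient.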
Status TransposeGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto inverted_perm = InvertPermutation(scope, op.input(1));
  grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Transpose", TransposeGrad);

Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  auto seq_lengths = op.input(1);
  int batch_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim));
  int seq_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim));
  grad_outputs->push_back(
      ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
                      ReverseSequence::BatchDim(batch_dim)));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad);

Status ReverseGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto reverse_dims = op.input(1);
  grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad);

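// ScatterNd writes `updates` into a zero tensor at `indices`, so the gradient
// with respect to `updates` is a GatherNd of the incoming gradient at those
// same indices. ScatterNdNonAliasingAdd additionally passes the incoming
// gradient through unchanged to its `input` operand.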
Status ScatterNdGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto indices = op.input(0);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad);

Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  auto indices = op.input(1);
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad);

template <bool IsPadV2>
Status PadGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = op.input(0);
  auto a = op.input(1);  // [Rank(x), 2]
  // Take the first column of a (the padding prepended to each dimension)
  // as a slice of shape [Rank(x), 1].
  auto size = Stack(scope, {Rank(scope, x), 1});
  auto pad_before = Slice(scope, a, {0, 0}, size);
  // Make it a 1-D tensor.
  auto begin = Reshape(scope, pad_before, {-1});
  grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
  grad_outputs->push_back(NoGradient());
  // PadV2 adds a "constant_values" input.
  if (IsPadV2) {
    grad_outputs->push_back(NoGradient());
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pad", PadGrad<false>);
REGISTER_GRADIENT_OP("PadV2", PadGrad<true>);

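// SpaceToBatch / BatchToSpace and SpaceToDepth / DepthToSpace are inverses of
// one another for a given block size, so each op's gradient is the opposite op
// applied to the incoming gradient; the paddings / crops inputs receive no
// gradient.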
Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(
      BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad);

Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad);

Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(
      SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad);

Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad);

Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad);

Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad);

Status MirrorPadGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
  grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
      scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad);

// TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4.
Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
  grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);

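// The gradient of StridedSlice is computed by the dedicated StridedSliceGrad
// op, which scatters the incoming gradient back into the original input shape
// using the same begin/end/strides inputs and mask attributes.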
Status StridedSliceGradHelper(const Scope& scope, const Operation& op,
                              const std::vector<Output>& grad_inputs,
                              std::vector<Output>* grad_outputs) {
  Input x = Shape(scope, op.input(0));
  Input begin = op.input(1);
  Input end = op.input(2);
  Input strides = op.input(3);
  int64 begin_mask;
  int64 end_mask;
  int64 ellipsis_mask;
  int64 new_axis_mask;
  int64 shrink_axis_mask;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "end_mask", &end_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "shrink_axis_mask", &shrink_axis_mask));
  grad_outputs->push_back(
      StridedSliceGrad(scope, x, begin, end, strides, grad_inputs[0],
                       StridedSliceGrad::BeginMask(begin_mask)
                           .EndMask(end_mask)
                           .EllipsisMask(ellipsis_mask)
                           .NewAxisMask(new_axis_mask)
                           .ShrinkAxisMask(shrink_axis_mask)));
  // No gradients returned for the begin, end, and strides inputs.
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper);

Status SliceGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Propagate the incoming gradient along all the selected values,
  // and zero everywhere else. Use the Pad operator for this.
  //
  // First create an Nx2 padding where N is the number of input
  // dimensions. The first column is the number of prepended zeros
  // for each dimension, and the second column is the number of
  // appended zeros.
  //
  // The first column is just the begin vector.
  // The second column is the input shape minus (begin + size), element-wise.

  // Running example:
  // input.shape = [3, 5, 3]
  // begin = [1, 2, 1], size = [1, 3, 2]
  Input input = op.input(0);
  Input begin = op.input(1);
  // input_rank = 3
  auto input_rank = Rank(scope, input);
  // slice_size = [1, 3, 2]
  auto slice_size = Shape(scope, op.output(0));
  // padding_shape = [3, 1]
  auto padding_shape = Stack(scope, {input_rank, 1});
  // before_padding = [[1]
  //                   [2]
  //                   [1]]
  Input before_padding = Reshape(scope, begin, padding_shape);
  // after_padding_sizes = shape(input) - slice_size - begin
  //                     = [3, 5, 3] - [1, 3, 2] - [1, 2, 1]
  //                     = [1, 0, 0]
  auto after_padding_sizes =
      Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin);
  // after_padding = [[1]
  //                  [0]
  //                  [0]]
  Input after_padding = Reshape(scope, after_padding_sizes, padding_shape);
  // paddings = [[1 1]
  //             [2 0]
  //             [1 0]]
  auto paddings =
      Concat(scope, {before_padding, after_padding}, Const(scope, 1));
  grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings));
  // Nothing propagated for the "begin" and "size" inputs.
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Slice", SliceGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow
497