/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/strings/strcat.h"

namespace tensorflow {
namespace ops {
namespace {
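
// The functions below implement symbolic gradients for array ops. Each one is
// registered with the C++ GradOpRegistry and is looked up when a backward
// graph is constructed, e.g. via AddSymbolicGradients() from
// "tensorflow/cc/framework/gradients.h". Illustrative usage (not part of this
// file):
//
//   Scope scope = Scope::NewRootScope();
//   auto x = Placeholder(scope, DT_FLOAT);
//   auto y = Identity(scope, x);
//   std::vector<Output> grads;
//   TF_CHECK_OK(AddSymbolicGradients(scope, {y}, {x}, &grads));
//
// The ops registered just below are marked as having no gradient: they are
// either non-differentiable or produce integer/shape metadata.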
REGISTER_NO_GRADIENT_OP("Const");
REGISTER_NO_GRADIENT_OP("StopGradient");
REGISTER_NO_GRADIENT_OP("ConcatOffset");
REGISTER_NO_GRADIENT_OP("EditDistance");
REGISTER_NO_GRADIENT_OP("ZerosLike");
REGISTER_NO_GRADIENT_OP("InvertPermutation");
REGISTER_NO_GRADIENT_OP("Shape");
REGISTER_NO_GRADIENT_OP("ShapeN");
REGISTER_NO_GRADIENT_OP("Rank");
REGISTER_NO_GRADIENT_OP("Size");
REGISTER_NO_GRADIENT_OP("BroadcastGradientArgs");
REGISTER_NO_GRADIENT_OP("OneHot");
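
// Pack (a.k.a. Stack) gradient: un-stack the incoming gradient along the same
// axis to recover one gradient per packed input. For example (shapes only):
// packing three [2, 3] tensors on axis 0 gives a [3, 2, 3] output, so the
// incoming [3, 2, 3] gradient is unstacked back into three [2, 3] gradients.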
Status PackGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  int N;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "N", &N));
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));

  grad_outputs->reserve(N);
  auto grad_op = Unstack(scope, grad_inputs[0], N, Unstack::Axis(axis));
  for (const Output& o : grad_op.output) {
    grad_outputs->emplace_back(o);
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pack", PackGrad);
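
// Unpack (a.k.a. Unstack) gradient: stack the per-output gradients back
// together along the same axis.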
Status UnpackGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  int axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
  grad_outputs->push_back(Stack(scope, grad_inputs, Stack::Axis(axis)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Unpack", UnpackGrad);

Status IdentityGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Identity", IdentityGrad);

Status RefIdentityGrad(const Scope& scope, const Operation& op,
                       const std::vector<Output>& grad_inputs,
                       std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("RefIdentity", RefIdentityGrad);
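
// The quantize-and-dequantize ops use a straight-through estimator: the
// incoming gradient is passed through unchanged to the input, and the range
// inputs (where present) receive no gradient.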
Status QuantizeAndDequantizeGrad(const Scope& scope, const Operation& op,
                                 const std::vector<Output>& grad_inputs,
                                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantize", QuantizeAndDequantizeGrad);
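
// QuantizeAndDequantizeV4 has a dedicated gradient op that, unlike the
// straight-through variants, also produces gradients for the input_min and
// input_max tensors.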
Status QuantizeAndDequantizeV4GradHelper(const Scope& scope,
                                         const Operation& op,
                                         const std::vector<Output>& grad_inputs,
                                         std::vector<Output>* grad_outputs) {
  Input input = Shape(scope, op.input(0));
  Input input_min = op.input(1);
  Input input_max = op.input(2);
  int64 axis;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "axis", &axis));
  auto qdq_v4_grad = QuantizeAndDequantizeV4Grad(
      scope, grad_inputs[0], input, input_min, input_max,
      QuantizeAndDequantizeV4Grad::Axis(axis));
  grad_outputs->push_back(qdq_v4_grad.input_backprop);
  grad_outputs->push_back(qdq_v4_grad.input_min_backprop);
  grad_outputs->push_back(qdq_v4_grad.input_max_backprop);
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV4",
                     QuantizeAndDequantizeV4GradHelper);

Status QuantizeAndDequantizeV3Grad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("QuantizeAndDequantizeV3", QuantizeAndDequantizeV3Grad);
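
// Split gradient: the split dimension (input 0) gets no gradient; the input
// tensor's gradient is the concatenation of the output gradients along that
// same dimension.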
Status SplitGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(Concat(scope, grad_inputs, op.input(0)));
  return scope.status();
}
REGISTER_GRADIENT_OP("Split", SplitGrad);

Status FillGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  // y = fill(fill_shape, x)
  // No gradient returned for the fill_shape argument.
  grad_outputs->push_back(NoGradient());
  // The gradient for x (which must be a scalar) is just the sum of
  // all the gradients from the shape it fills.
  // We use ReduceSum to implement this, which needs an argument providing
  // the indices of all the dimensions of the incoming gradient.
  // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))])
  auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]),
                        Const(scope, 1));
  grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims));
  return scope.status();
}
REGISTER_GRADIENT_OP("Fill", FillGrad);
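
// Diag and DiagPart are inverse transformations, so each one's gradient is the
// other applied to the incoming gradient; MatrixDiag is handled analogously
// via MatrixDiagPart.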
Status DiagGrad(const Scope& scope, const Operation& op,
                const std::vector<Output>& grad_inputs,
                std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(DiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("Diag", DiagGrad);

Status DiagPartGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(Diag(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("DiagPart", DiagPartGrad);

Status MatrixDiagGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(MatrixDiagPart(scope, grad_inputs[0]));
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixDiag", MatrixDiagGrad);
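
// MatrixBandPart keeps a band of each inner matrix and zeros everything else,
// so its gradient applies the same band mask to the incoming gradient; the
// num_lower/num_upper inputs get no gradient.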
Status MatrixBandPartGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  auto num_lower = op.input(1);
  auto num_upper = op.input(2);
  grad_outputs->push_back(
      MatrixBandPart(scope, grad_inputs[0], num_lower, num_upper));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MatrixBandPart", MatrixBandPartGrad);
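
// GatherNd gradient: scatter the incoming gradient back into a zero tensor of
// the original params shape at the gathered indices; the indices themselves
// get no gradient.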
Status GatherNdGrad(const Scope& scope, const Operation& op,
                    const std::vector<Output>& grad_inputs,
                    std::vector<Output>* grad_outputs) {
  auto ref = op.input(0);
  auto ref_shape = Shape(scope, ref);
  auto indices = op.input(1);
  grad_outputs->push_back(ScatterNd(scope, indices, grad_inputs[0], ref_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("GatherNd", GatherNdGrad);
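
// CheckNumerics gradient: wrap the incoming gradient in another CheckNumerics
// so that NaN/Inf values are also detected during the backward pass.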
Status CheckNumericsGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string message;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "message", &message));
  string err_msg = strings::StrCat(
      "Not a number (NaN) or infinity (Inf) values detected in gradient. ",
      message);
  grad_outputs->push_back(CheckNumerics(scope, grad_inputs[0], err_msg));
  return scope.status();
}
REGISTER_GRADIENT_OP("CheckNumerics", CheckNumericsGrad);
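
// Reshape, ExpandDims, and Squeeze only rearrange elements, so their gradient
// is the incoming gradient reshaped back to the shape of the original input.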
Status ReshapeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Reshape", ReshapeGrad);

Status ExpandDimsGrad(const Scope& scope, const Operation& op,
                      const std::vector<Output>& grad_inputs,
                      std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ExpandDims", ExpandDimsGrad);

Status SqueezeGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto input_shape = Shape(scope, op.input(0));
  grad_outputs->push_back(Reshape(scope, grad_inputs[0], input_shape));
  return scope.status();
}
REGISTER_GRADIENT_OP("Squeeze", SqueezeGrad);
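
// Transpose gradient: transpose the incoming gradient by the inverse of the
// original permutation; the permutation input gets no gradient.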
Status TransposeGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto inverted_perm = InvertPermutation(scope, op.input(1));
  grad_outputs->push_back(Transpose(scope, grad_inputs[0], inverted_perm));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Transpose", TransposeGrad);
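
// ReverseSequence and ReverseV2 are their own inverses, so each gradient
// re-applies the same reversal to the incoming gradient.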
Status ReverseSequenceGrad(const Scope& scope, const Operation& op,
                           const std::vector<Output>& grad_inputs,
                           std::vector<Output>* grad_outputs) {
  auto seq_lengths = op.input(1);
  int batch_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "batch_dim", &batch_dim));
  int seq_dim;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "seq_dim", &seq_dim));
  grad_outputs->push_back(
      ReverseSequence(scope, grad_inputs[0], seq_lengths, seq_dim,
                      ReverseSequence::BatchDim(batch_dim)));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseSequence", ReverseSequenceGrad);

Status ReverseGrad(const Scope& scope, const Operation& op,
                   const std::vector<Output>& grad_inputs,
                   std::vector<Output>* grad_outputs) {
  auto reverse_dims = op.input(1);
  grad_outputs->push_back(Reverse(scope, grad_inputs[0], reverse_dims));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ReverseV2", ReverseGrad);
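
// ScatterNd gradient: the updates receive the incoming gradient gathered at
// the scatter indices; the indices and shape inputs get no gradient. For
// ScatterNdNonAliasingAdd, the input additionally receives the incoming
// gradient unchanged.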
Status ScatterNdGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  auto indices = op.input(0);
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNd", ScatterNdGrad);

Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op,
                                   const std::vector<Output>& grad_inputs,
                                   std::vector<Output>* grad_outputs) {
  auto indices = op.input(1);
  grad_outputs->push_back(Identity(scope, grad_inputs[0]));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(GatherNd(scope, grad_inputs[0], indices));
  return scope.status();
}
REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad);
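
// Pad gradient: slice the region corresponding to the original input back out
// of the incoming gradient. The slice begins at the per-dimension "pad before"
// amounts (the first column of the paddings input) and has the shape of the
// original input.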
template <bool IsPadV2>
Status PadGrad(const Scope& scope, const Operation& op,
               const std::vector<Output>& grad_inputs,
               std::vector<Output>* grad_outputs) {
  auto x = op.input(0);
  auto a = op.input(1);  // [Rank(x), 2]
  // Takes a slice of a. The 1st column. [Rank(x), 1].
  auto size = Stack(scope, {Rank(scope, x), 1});
  auto pad_before = Slice(scope, a, {0, 0}, size);
  // Make it a 1-D tensor.
  auto begin = Reshape(scope, pad_before, {-1});
  grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
  grad_outputs->push_back(NoGradient());
  // PadV2 adds a "constant_values" input.
  if (IsPadV2) {
    grad_outputs->push_back(NoGradient());
  }
  return scope.status();
}
REGISTER_GRADIENT_OP("Pad", PadGrad<false>);
REGISTER_GRADIENT_OP("PadV2", PadGrad<true>);
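
// SpaceToBatch/BatchToSpace and their ND variants are mutually inverse, so
// each gradient applies the inverse op to the incoming gradient with the same
// block size; the paddings/crops inputs get no gradient.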
Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(
      BatchToSpace(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatch", SpaceToBatchGrad);

Status SpaceToBatchNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      BatchToSpaceND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToBatchND", SpaceToBatchNDGrad);

Status BatchToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(
      SpaceToBatch(scope, grad_inputs[0], op.input(1), block_size));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpace", BatchToSpaceGrad);

Status BatchToSpaceNDGrad(const Scope& scope, const Operation& op,
                          const std::vector<Output>& grad_inputs,
                          std::vector<Output>* grad_outputs) {
  grad_outputs->push_back(
      SpaceToBatchND(scope, grad_inputs[0], op.input(1), op.input(2)));
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("BatchToSpaceND", BatchToSpaceNDGrad);
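
// SpaceToDepth and DepthToSpace are likewise inverses of each other, so each
// gradient applies the other op with the same block size.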
Status SpaceToDepthGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(DepthToSpace(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("SpaceToDepth", SpaceToDepthGrad);

Status DepthToSpaceGrad(const Scope& scope, const Operation& op,
                        const std::vector<Output>& grad_inputs,
                        std::vector<Output>* grad_outputs) {
  int block_size;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "block_size", &block_size));
  grad_outputs->push_back(SpaceToDepth(scope, grad_inputs[0], block_size));
  return scope.status();
}
REGISTER_GRADIENT_OP("DepthToSpace", DepthToSpaceGrad);
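
// MirrorPad's gradient is computed by the internal MirrorPadGrad op; the
// gradient of MirrorPadGrad is in turn MirrorPad with the same mode.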
Status MirrorPadGrad(const Scope& scope, const Operation& op,
                     const std::vector<Output>& grad_inputs,
                     std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
  grad_outputs->push_back(tensorflow::ops::internal::MirrorPadGrad(
      scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPad", MirrorPadGrad);

// TODO(suharshs): b/34770860. This gradient was within 1e-3 but not 1e-4.
Status MirrorPadGradGrad(const Scope& scope, const Operation& op,
                         const std::vector<Output>& grad_inputs,
                         std::vector<Output>* grad_outputs) {
  string mode;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "mode", &mode));
  grad_outputs->push_back(MirrorPad(scope, grad_inputs[0], op.input(1), mode));
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("MirrorPadGrad", MirrorPadGradGrad);
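
// StridedSlice gradient: delegate to the StridedSliceGrad op, which needs the
// original input shape together with the same begin/end/strides inputs and
// mask attributes; begin, end, and strides get no gradient.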
Status StridedSliceGradHelper(const Scope& scope, const Operation& op,
                              const std::vector<Output>& grad_inputs,
                              std::vector<Output>* grad_outputs) {
  Input x = Shape(scope, op.input(0));
  Input begin = op.input(1);
  Input end = op.input(2);
  Input strides = op.input(3);
  int64 begin_mask;
  int64 end_mask;
  int64 ellipsis_mask;
  int64 new_axis_mask;
  int64 shrink_axis_mask;
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "begin_mask", &begin_mask));
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "end_mask", &end_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "ellipsis_mask", &ellipsis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "new_axis_mask", &new_axis_mask));
  TF_RETURN_IF_ERROR(
      GetNodeAttr(op.node()->attrs(), "shrink_axis_mask", &shrink_axis_mask));
  grad_outputs->push_back(
      StridedSliceGrad(scope, x, begin, end, strides, grad_inputs[0],
                       StridedSliceGrad::BeginMask(begin_mask)
                           .EndMask(end_mask)
                           .EllipsisMask(ellipsis_mask)
                           .NewAxisMask(new_axis_mask)
                           .ShrinkAxisMask(shrink_axis_mask)));
  // No gradients returned for begin, end and strides.
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("StridedSlice", StridedSliceGradHelper);

Status SliceGrad(const Scope& scope, const Operation& op,
                 const std::vector<Output>& grad_inputs,
                 std::vector<Output>* grad_outputs) {
  // Propagate the incoming gradient along all the selected values,
  // and zero everywhere else. Use the Pad operator for this.
  //
  // First create an Nx2 padding where N is the number of input
  // dimensions. The first column is the number of prepended zeros
  // for each dimension, and the second column is the number of
  // appended zeros.
  //
  // The first column is just the begin vector.
  // The second column is the shape of the input element-wise
  // subtracted by begin+size.

  // Running example:
  // input.shape = [3, 5, 3]
  // begin = [1, 2, 1], size = [1, 3, 2]
  Input input = op.input(0);
  Input begin = op.input(1);
  // input_rank = 3
  auto input_rank = Rank(scope, input);
  // slice_size = [1, 3, 2]
  auto slice_size = Shape(scope, op.output(0));
  // padding_shape = [3, 1]
  auto padding_shape = Stack(scope, {input_rank, 1});
  // before_padding = [[1]
  //                   [2]
  //                   [1]]
  Input before_padding = Reshape(scope, begin, padding_shape);
  // after_padding_sizes = shape(input) - slice_size - begin
  //                     = [3, 5, 3] - [1, 3, 2] - [1, 2, 1]
  //                     = [1, 0, 0]
  auto after_padding_sizes =
      Sub(scope, Sub(scope, Shape(scope, input), slice_size), begin);
  // after_padding = [[1]
  //                  [0]
  //                  [0]]
  Input after_padding = Reshape(scope, after_padding_sizes, padding_shape);
  // paddings = [[1 1]
  //             [2 0]
  //             [1 0]]
  auto paddings =
      Concat(scope, {before_padding, after_padding}, Const(scope, 1));
  grad_outputs->push_back(Pad(scope, grad_inputs[0], paddings));
  // Nothing propagated for "begin" and "size" inputs.
  grad_outputs->push_back(NoGradient());
  grad_outputs->push_back(NoGradient());
  return scope.status();
}
REGISTER_GRADIENT_OP("Slice", SliceGrad);

}  // anonymous namespace
}  // namespace ops
}  // namespace tensorflow