1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <vector>
17 #include "tensorflow/core/framework/function.h"
18 #include "tensorflow/core/lib/core/errors.h"
19
20 namespace tensorflow {
21
22 typedef FunctionDefHelper FDH;
23
24 REGISTER_OP_NO_GRADIENT("Shape");
25 REGISTER_OP_NO_GRADIENT("Rank");
26 REGISTER_OP_NO_GRADIENT("Size");
27 REGISTER_OP_NO_GRADIENT("ZerosLike");
28 REGISTER_OP_NO_GRADIENT("OnesLike");
29 REGISTER_OP_NO_GRADIENT("Const");
30 REGISTER_OP_NO_GRADIENT("EditDistance");
31 REGISTER_OP_NO_GRADIENT("StopGradient");
32
ReshapeGrad(const AttrSlice & attrs,FunctionDef * g)33 Status ReshapeGrad(const AttrSlice& attrs, FunctionDef* g) {
34 // clang-format off
35 *g = FDH::Define(
36 // Arg defs
37 {"x: T", "shape: int32", "dy: T"},
38 // Ret val defs
39 {"dx: T", "dshape: int32"},
40 // Attr defs
41 {"T: type"},
42 // Nodes
43 {
44 {{"x_shape"}, "Shape", {"x"}, {{"T", "$T"}}},
45 {{"dx"}, "Reshape", {"dy", "x_shape"}, {{"T", "$T"}}},
46 {{"dshape"}, "ZerosLike", {"shape"}, {{"T", DT_INT32}}},
47 });
48 // clang-format on
49 return Status::OK();
50 }
51 REGISTER_OP_GRADIENT("Reshape", ReshapeGrad);
52 REGISTER_OP_GRADIENT("ExpandDims", ReshapeGrad);
53
SqueezeGrad(const AttrSlice & attrs,FunctionDef * g)54 Status SqueezeGrad(const AttrSlice& attrs, FunctionDef* g) {
55 // clang-format off
56 *g = FDH::Define(
57 // Arg defs
58 {"x: T", "dy: T"},
59 // Ret val defs
60 {"dx: T"},
61 // Attr defs
62 {"T: type"},
63 // Nodes
64 {
65 {{"x_shape"}, "Shape", {"x"}, {{"T", "$T"}}},
66 {{"dx"}, "Reshape", {"dy", "x_shape"}, {{"T", "$T"}}},
67 });
68 // clang-format on
69 return Status::OK();
70 }
71 REGISTER_OP_GRADIENT("Squeeze", SqueezeGrad);
72
IdentityGrad(const AttrSlice & attrs,FunctionDef * g)73 Status IdentityGrad(const AttrSlice& attrs, FunctionDef* g) {
74 // clang-format off
75 *g = FDH::Define(
76 // Arg defs
77 {"x: T", "dy: T"},
78 // Ret val defs
79 {"dx: T"},
80 // Attr defs
81 {"T: type"},
82 // Nodes
83 {
84 {{"dx"}, "Identity", {"dy"}, {{"T", "$T"}}},
85 });
86 // clang-format on
87 VLOG(1) << "IdentityGrad " << DebugString(*g);
88 return Status::OK();
89 }
90 REGISTER_OP_GRADIENT("Identity", IdentityGrad);
91
PackGrad(const AttrSlice & attrs,FunctionDef * g)92 Status PackGrad(const AttrSlice& attrs, FunctionDef* g) {
93 // clang-format off
94 *g = FDH::Create(
95 "_",
96 // Arg defs
97 {"x: N*T", "dy: T"},
98 // Ret val defs
99 {"dx: N*T"},
100 // Attr defs
101 {"T: type", "N: int", "axis: int"},
102 // Nodes
103 {
104 {
105 {"dx"},
106 "Unpack",
107 {"dy"},
108 {{"T", "$T"}, {"num", "$N"}, {"axis", "$axis"}}
109 },
110 },
111 {{"dx", "dx:output"}});
112 // clang-format on
113 VLOG(1) << "PackGrad " << DebugString(*g);
114 return Status::OK();
115 }
116 REGISTER_OP_GRADIENT("Pack", PackGrad);
117
UnpackGrad(const AttrSlice & attrs,FunctionDef * g)118 Status UnpackGrad(const AttrSlice& attrs, FunctionDef* g) {
119 // clang-format off
120 *g = FDH::Define(
121 // Arg defs
122 {"x: T", "dy: num*T"},
123 // Ret val defs
124 {"dx: T"},
125 // Attr defs
126 {"T: type", "num: int", "axis: int"},
127 // Nodes
128 {
129 {
130 {"dx"},
131 "Pack",
132 {"dy"},
133 {{"T", "$T"}, {"N", "$num"}, {"axis", "$axis"}}
134 },
135 });
136 // clang-format on
137 VLOG(1) << "UnpackGrad " << DebugString(*g);
138 return Status::OK();
139 }
140 REGISTER_OP_GRADIENT("Unpack", UnpackGrad);
141
ConcatGradHelper(const AttrSlice & attrs,FunctionDef * g,bool dim_is_last_arg)142 Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g,
143 bool dim_is_last_arg) {
144 int N;
145 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &N));
146 DataType T;
147 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
148
149 std::vector<string> shape_i;
150 std::vector<string> offset_i;
151 std::vector<string> dx_i;
152 for (int i = 0; i < N; ++i) {
153 shape_i.push_back(strings::StrCat("shapes:output:", i));
154 offset_i.push_back(strings::StrCat("offset:offset:", i));
155 dx_i.push_back(strings::StrCat("dx_", i, ":output:0"));
156 }
157 DataTypeVector dtype_list(N, T);
158
159 // ConcatGrad(dim, x, dy):
160 // for i in range(N):
161 // dx[i] = Slice(dy, offset[i], shape[x[i]]),
162 // where offset[i] is the offset of x[i] in the output y,
163 // which is the same as dx[i]'s offset within dy.
164 std::vector<FDH::Node> nodes{
165 {{"shapes"}, "ShapeN", {"x"}, {{"T", "$T"}, {"N", "$N"}}},
166 {{"offset"}, "ConcatOffset", {"dim", "shapes:output"}, {{"N", "$N"}}},
167 {{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}},
168 {{"dx"},
169 "_ListToArray",
170 dx_i,
171 {{"T", "$T"}, {"N", "$N"}, {"Tin", DataTypeVector(N, T)}}}};
172
173 // For each dx[i], we take a slice of dy. The offset and size of the
174 // slice is given by offset[i] and shape[i].
175 for (int i = 0; i < N; ++i) {
176 nodes.push_back({{strings::StrCat("dx_", i)},
177 "Slice",
178 {"dy", offset_i[i], shape_i[i]},
179 {{"T", "$T"}, {"Index", DT_INT32}}});
180 }
181 if (dim_is_last_arg) {
182 // clang-format off
183 *g = FDH::Create(
184 "_",
185 // Arg defs
186 {"x: N*T", "dim: int32", "dy: T"},
187 // Return signature
188 {"dx: N*T", "d_dim: int32"},
189 // Attr defs
190 {"T: type", "N: int"},
191 // Nodes
192 nodes,
193 // Return values
194 {{"dx", "dx:output"}, {"d_dim", "d_dim:y:0"}});
195 // clang-format on
196 } else {
197 // clang-format off
198 *g = FDH::Create(
199 "_",
200 // Arg defs
201 {"dim: int32", "x: N*T", "dy: T"},
202 // Return signature
203 {"d_dim: int32", "dx: N*T"},
204 // Attr defs
205 {"T: type", "N: int"},
206 // Nodes
207 nodes,
208 // Return values
209 {{"dx", "dx:output"}, {"d_dim", "d_dim:y:0"}});
210 // clang-format on
211 }
212 VLOG(1) << "ConcatGrad " << DebugString(*g);
213 return Status::OK();
214 }
215
ConcatGrad(const AttrSlice & attrs,FunctionDef * g)216 Status ConcatGrad(const AttrSlice& attrs, FunctionDef* g) {
217 return ConcatGradHelper(attrs, g, false);
218 }
219
ConcatGradV2(const AttrSlice & attrs,FunctionDef * g)220 Status ConcatGradV2(const AttrSlice& attrs, FunctionDef* g) {
221 return ConcatGradHelper(attrs, g, true);
222 }
223
224 REGISTER_OP_GRADIENT("Concat", ConcatGrad);
225 REGISTER_OP_GRADIENT("ConcatV2", ConcatGradV2);
226
SplitGrad(const AttrSlice & attrs,FunctionDef * g)227 Status SplitGrad(const AttrSlice& attrs, FunctionDef* g) {
228 // clang-format off
229 *g = FDH::Define(
230 // Arg defs
231 {"dim: int32", "x: T", "dy: num_split*T"},
232 // Ret val defs
233 {"d_dim: int32", "dx: T"},
234 // Attr defs
235 {"T: type", "num_split: int"},
236 // Nodes
237 {
238 {{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}},
239 {{"dx"}, "Concat", {"dim", "dy"}, {{"T", "$T"}, {"N", "$num_split"}}}
240 });
241 // clang-format on
242 VLOG(1) << "SplitGrad " << DebugString(*g);
243 return Status::OK();
244 }
245 REGISTER_OP_GRADIENT("Split", SplitGrad);
246
SplitVGrad(const AttrSlice & attrs,FunctionDef * g)247 Status SplitVGrad(const AttrSlice& attrs, FunctionDef* g) {
248 // clang-format off
249 *g = FDH::Define(
250 // Arg defs
251 {"x: T", "size_splits: Tlen", "dim: int32", "dy: num_split*T"},
252 // Ret val defs
253 {"dx: T", "d_size_splits: Tlen", "d_dim: int32"},
254 // Attr defs
255 {"T: type", "Tlen: type", "num_split: int"},
256 // Nodes
257 {
258 {{"dx"}, "Concat", {"dim", "dy"}, {{"T", "$T"}, {"N", "$num_split"}}},
259 {{"d_size_splits"}, "ZerosLike", {"size_splits"}, {{"T", "$Tlen"}}},
260 {{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}},
261 });
262 // clang-format on
263 VLOG(1) << "SplitVGrad " << DebugString(*g);
264 return Status::OK();
265 }
266 REGISTER_OP_GRADIENT("SplitV", SplitVGrad);
267
ArrayToListGrad(const AttrSlice & attrs,FunctionDef * g)268 Status ArrayToListGrad(const AttrSlice& attrs, FunctionDef* g) {
269 int N;
270 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &N));
271 std::vector<string> dys;
272 dys.reserve(N);
273 for (int i = 0; i < N; ++i) {
274 dys.push_back(strings::StrCat("dy:", i));
275 }
276 // clang-format off
277 *g = FDH::Define(
278 // Arg defs
279 {"x: N*T", "dy: out_types"},
280 // Ret val defs
281 {"dx: N*T"},
282 // Attr defs
283 {"T: type", "N: int", "out_types: list(type)"},
284 // Nodes
285 {
286 {{"dx"}, "_ListToArray", dys,
287 {{"T", "$T"}, {"N", "$N"}, {"Tin", "$out_types"}}}
288 });
289 // clang-format on
290 VLOG(1) << "ArrayToListGrad " << DebugString(*g);
291 return Status::OK();
292 }
293 REGISTER_OP_GRADIENT("_ArrayToList", ArrayToListGrad);
294
ListToArrayGrad(const AttrSlice & attrs,FunctionDef * g)295 Status ListToArrayGrad(const AttrSlice& attrs, FunctionDef* g) {
296 // clang-format off
297 *g = FDH::Define(
298 // Arg defs
299 {"x: Tin", "dy: N*T"},
300 // Ret val defs
301 {"dx: Tin"},
302 // Attr defs
303 {"T: type", "N: int", "Tin: list(type)"},
304 // Nodes
305 {
306 {{"dx"}, "_ArrayToList", {"dy"},
307 {{"T", "$T"}, {"N", "$N"}, {"out_types", "$Tin"}}}
308 });
309 // clang-format on
310 VLOG(1) << "ListToArrayGrad " << DebugString(*g);
311 return Status::OK();
312 }
313 REGISTER_OP_GRADIENT("_ListToArray", ListToArrayGrad);
314
FillGrad(const AttrSlice & attrs,FunctionDef * g)315 Status FillGrad(const AttrSlice& attrs, FunctionDef* g) {
316 *g = FDH::Define(
317 // Arg defs
318 {"dims: int32", "x: T", "dy: T"},
319 // Ret val defs
320 {"d_dims: int32", "dx: T"},
321 // Attr defs
322 {"T: type"},
323 // Nodes
324 {
325 {{"d_dims"}, "ZerosLike", {"dims"}, {{"T", DT_INT32}}},
326 FDH::Const("zero", 0),
327 {{"rank"}, "Rank", {"dy"}, {{"T", "$T"}}},
328 FDH::Const("one", 1),
329 {{"r"}, "Range", {"zero", "rank", "one"}, {}},
330 // dx = sum(dy)
331 {{"dx"}, "Sum", {"dy", "r"}, {{"T", "$T"}}},
332 });
333 VLOG(1) << "FillGrad " << DebugString(*g);
334 return Status::OK();
335 }
336 REGISTER_OP_GRADIENT("Fill", FillGrad);
337
TransposeGrad(const AttrSlice & attrs,FunctionDef * g)338 Status TransposeGrad(const AttrSlice& attrs, FunctionDef* g) {
339 *g = FDH::Define(
340 // Arg defs
341 {"x: T", "p: int32", "dy: T"},
342 // Ret val defs
343 {"dx: T", "dp: int32"},
344 // Attr defs
345 {"T: type"},
346 // Nodes
347 {
348 {{"q"}, "InvertPermutation", {"p"}, {}},
349 {{"dx"}, "Transpose", {"dy", "q"}, {{"T", "$T"}}},
350 {{"dp"}, "ZerosLike", {"p"}, {{"T", DT_INT32}}},
351 });
352 VLOG(1) << "TransposeGrad " << DebugString(*g);
353 return Status::OK();
354 }
355 REGISTER_OP_GRADIENT("Transpose", TransposeGrad);
356
GatherNdGrad(const AttrSlice & attrs,FunctionDef * g)357 Status GatherNdGrad(const AttrSlice& attrs, FunctionDef* g) {
358 // clang-format off
359 *g = FDH::Define(
360 // Arg defs
361 {"params: Tparams", "indices: Tindices", "doutput: Tparams"},
362 // Ret val defs
363 {"dparams: Tparams", "dindices: Tindices"},
364 // Attr defs
365 {"Tparams: type", "Tindices: type"},
366 // Nodes
367 {
368 {{"x_shape"}, "Shape", {"params"}, {{"T", "$Tparams"}}},
369 {{"dparams"}, "ScatterNd", {"indices", "doutput", "x_shape"},
370 {{"T", "$Tparams"}, {"Tindices", "$Tindices"}}},
371 {{"dindices"}, "ZerosLike", {"indices"}, {{"T", "$Tindices"}}},
372 });
373 // clang-format on
374 return Status::OK();
375 }
376 REGISTER_OP_GRADIENT("GatherNd", GatherNdGrad);
377
ConjugateTransposeGrad(const AttrSlice & attrs,FunctionDef * g)378 Status ConjugateTransposeGrad(const AttrSlice& attrs, FunctionDef* g) {
379 *g = FDH::Define(
380 // Arg defs
381 {"x: T", "p: int32", "dy: T"},
382 // Ret val defs
383 {"dx: T", "dp: int32"},
384 // Attr defs
385 {"T: type"},
386 // Nodes
387 {
388 {{"q"}, "InvertPermutation", {"p"}, {}},
389 {{"dx"}, "ConjugateTranspose", {"dy", "q"}, {{"T", "$T"}}},
390 {{"dp"}, "ZerosLike", {"p"}, {{"T", DT_INT32}}},
391 });
392 VLOG(1) << "ConjugateTransposeGrad " << DebugString(*g);
393 return Status::OK();
394 }
395 REGISTER_OP_GRADIENT("ConjugateTranspose", ConjugateTransposeGrad);
396
ReverseGrad(const AttrSlice & attrs,FunctionDef * g)397 Status ReverseGrad(const AttrSlice& attrs, FunctionDef* g) {
398 *g = FDH::Define(
399 // Arg defs
400 {"x: T", "d: bool", "dy: T"},
401 // Ret val defs
402 {"dx: T", "dd: bool"},
403 // Attr defs
404 {"T: type"},
405 // Nodes
406 {
407 {{"dx"}, "Reverse", {"dy", "d"}, {{"T", "$T"}}},
408 {{"dd"}, "ZerosLike", {"d"}, {{"T", DT_BOOL}}},
409 });
410 VLOG(1) << "ReverseGrad " << DebugString(*g);
411 return Status::OK();
412 }
413 REGISTER_OP_GRADIENT("Reverse", ReverseGrad);
414
ReverseV2Grad(const AttrSlice & attrs,FunctionDef * g)415 Status ReverseV2Grad(const AttrSlice& attrs, FunctionDef* g) {
416 DataType itype;
417 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &itype));
418 if (itype != DT_INT32) {
419 return errors::Unimplemented(
420 "ReverseV2Grad for int64 index are not supported.");
421 }
422 *g = FDH::Define(
423 // Arg defs
424 {"x: T", "d: int32", "dy: T"},
425 // Ret val defs
426 {"dx: T", "dd: int32"},
427 // Attr defs
428 {"T: type", "Tidx: {int32, int64}"},
429 // Nodes
430 {
431 {{"dx"}, "ReverseV2", {"dy", "d"}, {{"T", "$T"}}},
432 {{"dd"}, "ZerosLike", {"d"}, {{"T", "$Tidx"}}},
433 });
434 VLOG(1) << "ReverseGrad " << DebugString(*g);
435 return Status::OK();
436 }
437 REGISTER_OP_GRADIENT("ReverseV2", ReverseV2Grad);
438
SliceGrad(const AttrSlice & attrs,FunctionDef * g)439 Status SliceGrad(const AttrSlice& attrs, FunctionDef* g) {
440 DataType itype;
441 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index", &itype));
442 if (itype != DT_INT32) {
443 return errors::Unimplemented(
444 "SliceGrad for int64 index are not supported.");
445 }
446 *g = FDH::Define(
447 // Arg defs
448 {"x: T", "begin: int32", "size: int32", "dy: T"},
449 // Ret val defs
450 {"dx: T", "begin_grad: int32", "size_grad: int32"},
451 // Attr defs
452 {"T: type"},
453 // Nodes
454 {// paddings = concat(1, [begin, shape(x) - begin - size])
455 FDH::Const("one", 1),
456 {{"b1"}, "ExpandDims", {"begin", "one"}, {{"T", DT_INT32}}},
457 {{"xs"}, "Shape", {"x"}, {{"T", "$T"}}},
458 {{"xs_b"}, "Sub", {"xs", "begin"}, {{"T", DT_INT32}}},
459 {{"xs_b_s"}, "Sub", {"xs_b", "size"}, {{"T", DT_INT32}}},
460 {{"a1"}, "ExpandDims", {"xs_b_s", "one"}, {{"T", DT_INT32}}},
461 {{"paddings"},
462 "Concat",
463 {"one", "b1", "a1"},
464 {{"N", 2}, {"T", DT_INT32}}},
465 // dx = Pad(dy, paddings)
466 {{"dx"}, "Pad", {"dy", "paddings"}, {{"T", "$T"}}},
467 {{"begin_grad"}, "ZerosLike", {"begin"}, {{"T", DT_INT32}}},
468 {{"size_grad"}, "ZerosLike", {"size"}, {{"T", DT_INT32}}}});
469 VLOG(1) << "SliceGrad " << DebugString(*g);
470 return Status::OK();
471 }
472 REGISTER_OP_GRADIENT("Slice", SliceGrad);
473
StridedSliceGrad(const AttrSlice & attrs,FunctionDef * g)474 Status StridedSliceGrad(const AttrSlice& attrs, FunctionDef* g) {
475 DataType itype;
476 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index", &itype));
477 if (itype != DT_INT32) {
478 return errors::Unimplemented(
479 "SliceGrad for int64 index are not supported.");
480 }
481
482 *g = FDH::Define(
483 // Arg defs
484 {"x: T", "begin: int32", "end: int32", "stride: int32", "dy: T"},
485 // Ret val defs
486 {"dx: T", "begin_grad: int32", "end_grad: int32", "stride_grad: int32"},
487 // Attr defs
488 {"T: type", "Index: {int32, int64}", "begin_mask: int", "end_mask: int",
489 "ellipsis_mask: int", "new_axis_mask: int", "shrink_axis_mask: int"},
490 {// Nodes
491 {{{"xs"}, "Shape", {"x"}, {{"T", "$T"}}},
492 {{"dx"},
493 "StridedSliceGrad",
494 {"xs", "begin", "end", "stride", "dy"},
495 {{"T", "$T"},
496 {"Index", "$Index"},
497 {"begin_mask", "$begin_mask"},
498 {"end_mask", "$end_mask"},
499 {"ellipsis_mask", "$ellipsis_mask"},
500 {"new_axis_mask", "$new_axis_mask"},
501 {"shrink_axis_mask", "$shrink_axis_mask"}}},
502 {{"begin_grad"}, "ZerosLike", {"begin"}, {{"T", DT_INT32}}},
503 {{"end_grad"}, "ZerosLike", {"end"}, {{"T", DT_INT32}}},
504 {{"stride_grad"}, "ZerosLike", {"stride"}, {{"T", DT_INT32}}}}});
505
506 VLOG(1) << "StridedSliceGrad " << DebugString(*g);
507 return Status::OK();
508 }
509 REGISTER_OP_GRADIENT("StridedSlice", StridedSliceGrad);
510
StridedSliceGradGrad(const AttrSlice & attrs,FunctionDef * g)511 Status StridedSliceGradGrad(const AttrSlice& attrs, FunctionDef* g) {
512 DataType itype;
513 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index", &itype));
514 if (itype != DT_INT32) {
515 return errors::Unimplemented(
516 "SliceGrad for int64 index are not supported.");
517 }
518
519 // TODO(aselle): Shouldn't the int32 tensors return zeros of shape like
520 // dy_grad?
521 // I'm following slice's behavior for now.
522 *g = FDH::Define(
523 // Arg defs
524 {"shape: int32", "begin: int32", "end: int32", "stride: int32", "dy: T",
525 "grad: T"},
526 // Ret val defs
527 {"shape_grad: int32", "begin_grad: int32", "end_grad: int32",
528 "stride_grad: int32", "dy_grad: T"},
529 // Attr defs
530 {"T: type", "Index: {int32, int64}", "begin_mask: int", "end_mask: int",
531 "ellipsis_mask: int", "new_axis_mask: int", "shrink_axis_mask: int"},
532 {// Nodes
533 {{{"shape_grad"}, "ZerosLike", {"shape"}, {{"T", DT_INT32}}},
534 {{"begin_grad"}, "ZerosLike", {"begin"}, {{"T", DT_INT32}}},
535 {{"end_grad"}, "ZerosLike", {"end"}, {{"T", DT_INT32}}},
536 {{"stride_grad"}, "ZerosLike", {"stride"}, {{"T", DT_INT32}}},
537 {{"dy_grad"},
538 "StridedSlice",
539 {"grad", "begin", "end", "stride"},
540 {{"T", "$T"},
541 {"Index", "$Index"},
542 {"begin_mask", "$begin_mask"},
543 {"end_mask", "$end_mask"},
544 {"ellipsis_mask", "$ellipsis_mask"},
545 {"new_axis_mask", "$new_axis_mask"},
546 {"shrink_axis_mask", "$shrink_axis_mask"}}}}});
547
548 VLOG(1) << "StridedSliceGrad " << DebugString(*g);
549 return Status::OK();
550 }
551 REGISTER_OP_GRADIENT("StridedSliceGrad", StridedSliceGradGrad);
552
BroadcastToGrad(const AttrSlice & attrs,FunctionDef * g)553 Status BroadcastToGrad(const AttrSlice& attrs, FunctionDef* g) {
554 DataType itype;
555 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &itype));
556 if (itype != DT_INT32) {
557 return errors::Unimplemented(
558 "BroadcastToGrad for int64 index are not supported.");
559 }
560 std::vector<FDH::Node> nodes = {
561 {{"sx"}, "Shape", {"x"}, {{"T", "$T"}}},
562 {{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "shape"}},
563 {{"sum_gx"}, "Sum", {"dy", "rx"}, {{"T", "$T"}}},
564 {{"dx"}, "Reshape", {"sum_gx", "sx"}, {{"T", "$T"}}},
565 {{"dshape"}, "ZerosLike", {"shape"}, {{"T", "$Tidx"}}}};
566 *g = FDH::Define(
567 // Arg defs
568 {"x: T", "shape: int32", "dy: T"},
569 // Ret val defs
570 {"dx: T", "dshape: Tidx"},
571 // Attr defs
572 {{"T: type"}, {"Tidx: {int32, int64}"}},
573 // Nodes
574 nodes);
575 return Status::OK();
576 }
577 REGISTER_OP_GRADIENT("BroadcastTo", BroadcastToGrad);
578
579 } // end namespace tensorflow
580