• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <vector>
17 #include "tensorflow/core/framework/function.h"
18 #include "tensorflow/core/lib/core/errors.h"
19 
20 namespace tensorflow {
21 
22 typedef FunctionDefHelper FDH;
23 
24 REGISTER_OP_NO_GRADIENT("Shape");
25 REGISTER_OP_NO_GRADIENT("Rank");
26 REGISTER_OP_NO_GRADIENT("Size");
27 REGISTER_OP_NO_GRADIENT("ZerosLike");
28 REGISTER_OP_NO_GRADIENT("OnesLike");
29 REGISTER_OP_NO_GRADIENT("Const");
30 REGISTER_OP_NO_GRADIENT("EditDistance");
31 REGISTER_OP_NO_GRADIENT("StopGradient");
32 
ReshapeGrad(const AttrSlice & attrs,FunctionDef * g)33 Status ReshapeGrad(const AttrSlice& attrs, FunctionDef* g) {
34   // clang-format off
35   *g = FDH::Define(
36       // Arg defs
37       {"x: T", "shape: int32", "dy: T"},
38       // Ret val defs
39       {"dx: T", "dshape: int32"},
40       // Attr defs
41       {"T: type"},
42       // Nodes
43       {
44         {{"x_shape"}, "Shape", {"x"}, {{"T", "$T"}}},
45         {{"dx"}, "Reshape", {"dy", "x_shape"}, {{"T", "$T"}}},
46         {{"dshape"}, "ZerosLike", {"shape"}, {{"T", DT_INT32}}},
47       });
48   // clang-format on
49   return Status::OK();
50 }
51 REGISTER_OP_GRADIENT("Reshape", ReshapeGrad);
52 REGISTER_OP_GRADIENT("ExpandDims", ReshapeGrad);
53 
SqueezeGrad(const AttrSlice & attrs,FunctionDef * g)54 Status SqueezeGrad(const AttrSlice& attrs, FunctionDef* g) {
55   // clang-format off
56   *g = FDH::Define(
57       // Arg defs
58       {"x: T", "dy: T"},
59       // Ret val defs
60       {"dx: T"},
61       // Attr defs
62       {"T: type"},
63       // Nodes
64       {
65         {{"x_shape"}, "Shape", {"x"}, {{"T", "$T"}}},
66         {{"dx"}, "Reshape", {"dy", "x_shape"}, {{"T", "$T"}}},
67       });
68   // clang-format on
69   return Status::OK();
70 }
71 REGISTER_OP_GRADIENT("Squeeze", SqueezeGrad);
72 
IdentityGrad(const AttrSlice & attrs,FunctionDef * g)73 Status IdentityGrad(const AttrSlice& attrs, FunctionDef* g) {
74   // clang-format off
75   *g = FDH::Define(
76       // Arg defs
77       {"x: T", "dy: T"},
78       // Ret val defs
79       {"dx: T"},
80       // Attr defs
81       {"T: type"},
82       // Nodes
83       {
84         {{"dx"}, "Identity", {"dy"}, {{"T", "$T"}}},
85       });
86   // clang-format on
87   VLOG(1) << "IdentityGrad " << DebugString(*g);
88   return Status::OK();
89 }
90 REGISTER_OP_GRADIENT("Identity", IdentityGrad);
91 
PackGrad(const AttrSlice & attrs,FunctionDef * g)92 Status PackGrad(const AttrSlice& attrs, FunctionDef* g) {
93   // clang-format off
94   *g = FDH::Create(
95       "_",
96       // Arg defs
97       {"x: N*T", "dy: T"},
98       // Ret val defs
99       {"dx: N*T"},
100       // Attr defs
101       {"T: type", "N: int", "axis: int"},
102       // Nodes
103       {
104         {
105           {"dx"},
106           "Unpack",
107           {"dy"},
108           {{"T", "$T"}, {"num", "$N"}, {"axis", "$axis"}}
109         },
110       },
111       {{"dx", "dx:output"}});
112   // clang-format on
113   VLOG(1) << "PackGrad " << DebugString(*g);
114   return Status::OK();
115 }
116 REGISTER_OP_GRADIENT("Pack", PackGrad);
117 
UnpackGrad(const AttrSlice & attrs,FunctionDef * g)118 Status UnpackGrad(const AttrSlice& attrs, FunctionDef* g) {
119   // clang-format off
120   *g = FDH::Define(
121       // Arg defs
122       {"x: T", "dy: num*T"},
123       // Ret val defs
124       {"dx: T"},
125       // Attr defs
126       {"T: type", "num: int", "axis: int"},
127       // Nodes
128       {
129         {
130           {"dx"},
131           "Pack",
132           {"dy"},
133           {{"T", "$T"}, {"N", "$num"}, {"axis", "$axis"}}
134         },
135       });
136   // clang-format on
137   VLOG(1) << "UnpackGrad " << DebugString(*g);
138   return Status::OK();
139 }
140 REGISTER_OP_GRADIENT("Unpack", UnpackGrad);
141 
ConcatGradHelper(const AttrSlice & attrs,FunctionDef * g,bool dim_is_last_arg)142 Status ConcatGradHelper(const AttrSlice& attrs, FunctionDef* g,
143                         bool dim_is_last_arg) {
144   int N;
145   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &N));
146   DataType T;
147   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
148 
149   std::vector<string> shape_i;
150   std::vector<string> offset_i;
151   std::vector<string> dx_i;
152   for (int i = 0; i < N; ++i) {
153     shape_i.push_back(strings::StrCat("shapes:output:", i));
154     offset_i.push_back(strings::StrCat("offset:offset:", i));
155     dx_i.push_back(strings::StrCat("dx_", i, ":output:0"));
156   }
157   DataTypeVector dtype_list(N, T);
158 
159   // ConcatGrad(dim, x, dy):
160   //   for i in range(N):
161   //     dx[i] = Slice(dy, offset[i], shape[x[i]]),
162   // where offset[i] is the offset of x[i] in the output y,
163   // which is the same as dx[i]'s offset within dy.
164   std::vector<FDH::Node> nodes{
165       {{"shapes"}, "ShapeN", {"x"}, {{"T", "$T"}, {"N", "$N"}}},
166       {{"offset"}, "ConcatOffset", {"dim", "shapes:output"}, {{"N", "$N"}}},
167       {{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}},
168       {{"dx"},
169        "_ListToArray",
170        dx_i,
171        {{"T", "$T"}, {"N", "$N"}, {"Tin", DataTypeVector(N, T)}}}};
172 
173   // For each dx[i], we take a slice of dy. The offset and size of the
174   // slice is given by offset[i] and shape[i].
175   for (int i = 0; i < N; ++i) {
176     nodes.push_back({{strings::StrCat("dx_", i)},
177                      "Slice",
178                      {"dy", offset_i[i], shape_i[i]},
179                      {{"T", "$T"}, {"Index", DT_INT32}}});
180   }
181   if (dim_is_last_arg) {
182     // clang-format off
183     *g = FDH::Create(
184         "_",
185         // Arg defs
186         {"x: N*T", "dim: int32", "dy: T"},
187         // Return signature
188         {"dx: N*T", "d_dim: int32"},
189         // Attr defs
190         {"T: type", "N: int"},
191         // Nodes
192         nodes,
193         // Return values
194         {{"dx", "dx:output"}, {"d_dim", "d_dim:y:0"}});
195     // clang-format on
196   } else {
197     // clang-format off
198     *g = FDH::Create(
199         "_",
200         // Arg defs
201         {"dim: int32", "x: N*T", "dy: T"},
202         // Return signature
203         {"d_dim: int32", "dx: N*T"},
204         // Attr defs
205         {"T: type", "N: int"},
206         // Nodes
207         nodes,
208         // Return values
209         {{"dx", "dx:output"}, {"d_dim", "d_dim:y:0"}});
210     // clang-format on
211   }
212   VLOG(1) << "ConcatGrad " << DebugString(*g);
213   return Status::OK();
214 }
215 
ConcatGrad(const AttrSlice & attrs,FunctionDef * g)216 Status ConcatGrad(const AttrSlice& attrs, FunctionDef* g) {
217   return ConcatGradHelper(attrs, g, false);
218 }
219 
ConcatGradV2(const AttrSlice & attrs,FunctionDef * g)220 Status ConcatGradV2(const AttrSlice& attrs, FunctionDef* g) {
221   return ConcatGradHelper(attrs, g, true);
222 }
223 
224 REGISTER_OP_GRADIENT("Concat", ConcatGrad);
225 REGISTER_OP_GRADIENT("ConcatV2", ConcatGradV2);
226 
SplitGrad(const AttrSlice & attrs,FunctionDef * g)227 Status SplitGrad(const AttrSlice& attrs, FunctionDef* g) {
228   // clang-format off
229   *g = FDH::Define(
230       // Arg defs
231       {"dim: int32", "x: T", "dy: num_split*T"},
232       // Ret val defs
233       {"d_dim: int32", "dx: T"},
234       // Attr defs
235       {"T: type", "num_split: int"},
236       // Nodes
237       {
238         {{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}},
239         {{"dx"}, "Concat", {"dim", "dy"}, {{"T", "$T"}, {"N", "$num_split"}}}
240       });
241   // clang-format on
242   VLOG(1) << "SplitGrad " << DebugString(*g);
243   return Status::OK();
244 }
245 REGISTER_OP_GRADIENT("Split", SplitGrad);
246 
SplitVGrad(const AttrSlice & attrs,FunctionDef * g)247 Status SplitVGrad(const AttrSlice& attrs, FunctionDef* g) {
248   // clang-format off
249   *g = FDH::Define(
250       // Arg defs
251       {"x: T", "size_splits: Tlen", "dim: int32", "dy: num_split*T"},
252       // Ret val defs
253       {"dx: T", "d_size_splits: Tlen", "d_dim: int32"},
254       // Attr defs
255       {"T: type", "Tlen: type", "num_split: int"},
256       // Nodes
257       {
258         {{"dx"}, "Concat", {"dim", "dy"}, {{"T", "$T"}, {"N", "$num_split"}}},
259         {{"d_size_splits"}, "ZerosLike", {"size_splits"}, {{"T", "$Tlen"}}},
260         {{"d_dim"}, "ZerosLike", {"dim"}, {{"T", DT_INT32}}},
261       });
262   // clang-format on
263   VLOG(1) << "SplitVGrad " << DebugString(*g);
264   return Status::OK();
265 }
266 REGISTER_OP_GRADIENT("SplitV", SplitVGrad);
267 
ArrayToListGrad(const AttrSlice & attrs,FunctionDef * g)268 Status ArrayToListGrad(const AttrSlice& attrs, FunctionDef* g) {
269   int N;
270   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &N));
271   std::vector<string> dys;
272   dys.reserve(N);
273   for (int i = 0; i < N; ++i) {
274     dys.push_back(strings::StrCat("dy:", i));
275   }
276   // clang-format off
277   *g = FDH::Define(
278       // Arg defs
279       {"x: N*T", "dy: out_types"},
280       // Ret val defs
281       {"dx: N*T"},
282       // Attr defs
283       {"T: type", "N: int", "out_types: list(type)"},
284       // Nodes
285       {
286         {{"dx"}, "_ListToArray", dys,
287          {{"T", "$T"}, {"N", "$N"}, {"Tin", "$out_types"}}}
288       });
289   // clang-format on
290   VLOG(1) << "ArrayToListGrad " << DebugString(*g);
291   return Status::OK();
292 }
293 REGISTER_OP_GRADIENT("_ArrayToList", ArrayToListGrad);
294 
ListToArrayGrad(const AttrSlice & attrs,FunctionDef * g)295 Status ListToArrayGrad(const AttrSlice& attrs, FunctionDef* g) {
296   // clang-format off
297   *g = FDH::Define(
298       // Arg defs
299       {"x: Tin", "dy: N*T"},
300       // Ret val defs
301       {"dx: Tin"},
302       // Attr defs
303       {"T: type", "N: int", "Tin: list(type)"},
304       // Nodes
305       {
306         {{"dx"}, "_ArrayToList", {"dy"},
307          {{"T", "$T"}, {"N", "$N"}, {"out_types", "$Tin"}}}
308       });
309   // clang-format on
310   VLOG(1) << "ListToArrayGrad " << DebugString(*g);
311   return Status::OK();
312 }
313 REGISTER_OP_GRADIENT("_ListToArray", ListToArrayGrad);
314 
FillGrad(const AttrSlice & attrs,FunctionDef * g)315 Status FillGrad(const AttrSlice& attrs, FunctionDef* g) {
316   *g = FDH::Define(
317       // Arg defs
318       {"dims: int32", "x: T", "dy: T"},
319       // Ret val defs
320       {"d_dims: int32", "dx: T"},
321       // Attr defs
322       {"T: type"},
323       // Nodes
324       {
325           {{"d_dims"}, "ZerosLike", {"dims"}, {{"T", DT_INT32}}},
326           FDH::Const("zero", 0),
327           {{"rank"}, "Rank", {"dy"}, {{"T", "$T"}}},
328           FDH::Const("one", 1),
329           {{"r"}, "Range", {"zero", "rank", "one"}, {}},
330           // dx = sum(dy)
331           {{"dx"}, "Sum", {"dy", "r"}, {{"T", "$T"}}},
332       });
333   VLOG(1) << "FillGrad " << DebugString(*g);
334   return Status::OK();
335 }
336 REGISTER_OP_GRADIENT("Fill", FillGrad);
337 
TransposeGrad(const AttrSlice & attrs,FunctionDef * g)338 Status TransposeGrad(const AttrSlice& attrs, FunctionDef* g) {
339   *g = FDH::Define(
340       // Arg defs
341       {"x: T", "p: int32", "dy: T"},
342       // Ret val defs
343       {"dx: T", "dp: int32"},
344       // Attr defs
345       {"T: type"},
346       // Nodes
347       {
348           {{"q"}, "InvertPermutation", {"p"}, {}},
349           {{"dx"}, "Transpose", {"dy", "q"}, {{"T", "$T"}}},
350           {{"dp"}, "ZerosLike", {"p"}, {{"T", DT_INT32}}},
351       });
352   VLOG(1) << "TransposeGrad " << DebugString(*g);
353   return Status::OK();
354 }
355 REGISTER_OP_GRADIENT("Transpose", TransposeGrad);
356 
GatherNdGrad(const AttrSlice & attrs,FunctionDef * g)357 Status GatherNdGrad(const AttrSlice& attrs, FunctionDef* g) {
358   // clang-format off
359   *g = FDH::Define(
360       // Arg defs
361       {"params: Tparams", "indices: Tindices", "doutput: Tparams"},
362       // Ret val defs
363       {"dparams: Tparams", "dindices: Tindices"},
364       // Attr defs
365       {"Tparams: type", "Tindices: type"},
366       // Nodes
367       {
368         {{"x_shape"}, "Shape", {"params"}, {{"T", "$Tparams"}}},
369         {{"dparams"}, "ScatterNd", {"indices", "doutput", "x_shape"},
370          {{"T", "$Tparams"}, {"Tindices", "$Tindices"}}},
371         {{"dindices"}, "ZerosLike", {"indices"}, {{"T", "$Tindices"}}},
372       });
373   // clang-format on
374   return Status::OK();
375 }
376 REGISTER_OP_GRADIENT("GatherNd", GatherNdGrad);
377 
ConjugateTransposeGrad(const AttrSlice & attrs,FunctionDef * g)378 Status ConjugateTransposeGrad(const AttrSlice& attrs, FunctionDef* g) {
379   *g = FDH::Define(
380       // Arg defs
381       {"x: T", "p: int32", "dy: T"},
382       // Ret val defs
383       {"dx: T", "dp: int32"},
384       // Attr defs
385       {"T: type"},
386       // Nodes
387       {
388           {{"q"}, "InvertPermutation", {"p"}, {}},
389           {{"dx"}, "ConjugateTranspose", {"dy", "q"}, {{"T", "$T"}}},
390           {{"dp"}, "ZerosLike", {"p"}, {{"T", DT_INT32}}},
391       });
392   VLOG(1) << "ConjugateTransposeGrad " << DebugString(*g);
393   return Status::OK();
394 }
395 REGISTER_OP_GRADIENT("ConjugateTranspose", ConjugateTransposeGrad);
396 
ReverseGrad(const AttrSlice & attrs,FunctionDef * g)397 Status ReverseGrad(const AttrSlice& attrs, FunctionDef* g) {
398   *g = FDH::Define(
399       // Arg defs
400       {"x: T", "d: bool", "dy: T"},
401       // Ret val defs
402       {"dx: T", "dd: bool"},
403       // Attr defs
404       {"T: type"},
405       // Nodes
406       {
407           {{"dx"}, "Reverse", {"dy", "d"}, {{"T", "$T"}}},
408           {{"dd"}, "ZerosLike", {"d"}, {{"T", DT_BOOL}}},
409       });
410   VLOG(1) << "ReverseGrad " << DebugString(*g);
411   return Status::OK();
412 }
413 REGISTER_OP_GRADIENT("Reverse", ReverseGrad);
414 
ReverseV2Grad(const AttrSlice & attrs,FunctionDef * g)415 Status ReverseV2Grad(const AttrSlice& attrs, FunctionDef* g) {
416   DataType itype;
417   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &itype));
418   if (itype != DT_INT32) {
419     return errors::Unimplemented(
420         "ReverseV2Grad for int64 index are not supported.");
421   }
422   *g = FDH::Define(
423       // Arg defs
424       {"x: T", "d: int32", "dy: T"},
425       // Ret val defs
426       {"dx: T", "dd: int32"},
427       // Attr defs
428       {"T: type", "Tidx: {int32, int64}"},
429       // Nodes
430       {
431           {{"dx"}, "ReverseV2", {"dy", "d"}, {{"T", "$T"}}},
432           {{"dd"}, "ZerosLike", {"d"}, {{"T", "$Tidx"}}},
433       });
434   VLOG(1) << "ReverseGrad " << DebugString(*g);
435   return Status::OK();
436 }
437 REGISTER_OP_GRADIENT("ReverseV2", ReverseV2Grad);
438 
SliceGrad(const AttrSlice & attrs,FunctionDef * g)439 Status SliceGrad(const AttrSlice& attrs, FunctionDef* g) {
440   DataType itype;
441   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index", &itype));
442   if (itype != DT_INT32) {
443     return errors::Unimplemented(
444         "SliceGrad for int64 index are not supported.");
445   }
446   *g = FDH::Define(
447       // Arg defs
448       {"x: T", "begin: int32", "size: int32", "dy: T"},
449       // Ret val defs
450       {"dx: T", "begin_grad: int32", "size_grad: int32"},
451       // Attr defs
452       {"T: type"},
453       // Nodes
454       {// paddings = concat(1, [begin, shape(x) - begin - size])
455        FDH::Const("one", 1),
456        {{"b1"}, "ExpandDims", {"begin", "one"}, {{"T", DT_INT32}}},
457        {{"xs"}, "Shape", {"x"}, {{"T", "$T"}}},
458        {{"xs_b"}, "Sub", {"xs", "begin"}, {{"T", DT_INT32}}},
459        {{"xs_b_s"}, "Sub", {"xs_b", "size"}, {{"T", DT_INT32}}},
460        {{"a1"}, "ExpandDims", {"xs_b_s", "one"}, {{"T", DT_INT32}}},
461        {{"paddings"},
462         "Concat",
463         {"one", "b1", "a1"},
464         {{"N", 2}, {"T", DT_INT32}}},
465        // dx = Pad(dy, paddings)
466        {{"dx"}, "Pad", {"dy", "paddings"}, {{"T", "$T"}}},
467        {{"begin_grad"}, "ZerosLike", {"begin"}, {{"T", DT_INT32}}},
468        {{"size_grad"}, "ZerosLike", {"size"}, {{"T", DT_INT32}}}});
469   VLOG(1) << "SliceGrad " << DebugString(*g);
470   return Status::OK();
471 }
472 REGISTER_OP_GRADIENT("Slice", SliceGrad);
473 
StridedSliceGrad(const AttrSlice & attrs,FunctionDef * g)474 Status StridedSliceGrad(const AttrSlice& attrs, FunctionDef* g) {
475   DataType itype;
476   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index", &itype));
477   if (itype != DT_INT32) {
478     return errors::Unimplemented(
479         "SliceGrad for int64 index are not supported.");
480   }
481 
482   *g = FDH::Define(
483       // Arg defs
484       {"x: T", "begin: int32", "end: int32", "stride: int32", "dy: T"},
485       // Ret val defs
486       {"dx: T", "begin_grad: int32", "end_grad: int32", "stride_grad: int32"},
487       // Attr defs
488       {"T: type", "Index: {int32, int64}", "begin_mask: int", "end_mask: int",
489        "ellipsis_mask: int", "new_axis_mask: int", "shrink_axis_mask: int"},
490       {// Nodes
491        {{{"xs"}, "Shape", {"x"}, {{"T", "$T"}}},
492         {{"dx"},
493          "StridedSliceGrad",
494          {"xs", "begin", "end", "stride", "dy"},
495          {{"T", "$T"},
496           {"Index", "$Index"},
497           {"begin_mask", "$begin_mask"},
498           {"end_mask", "$end_mask"},
499           {"ellipsis_mask", "$ellipsis_mask"},
500           {"new_axis_mask", "$new_axis_mask"},
501           {"shrink_axis_mask", "$shrink_axis_mask"}}},
502         {{"begin_grad"}, "ZerosLike", {"begin"}, {{"T", DT_INT32}}},
503         {{"end_grad"}, "ZerosLike", {"end"}, {{"T", DT_INT32}}},
504         {{"stride_grad"}, "ZerosLike", {"stride"}, {{"T", DT_INT32}}}}});
505 
506   VLOG(1) << "StridedSliceGrad " << DebugString(*g);
507   return Status::OK();
508 }
509 REGISTER_OP_GRADIENT("StridedSlice", StridedSliceGrad);
510 
StridedSliceGradGrad(const AttrSlice & attrs,FunctionDef * g)511 Status StridedSliceGradGrad(const AttrSlice& attrs, FunctionDef* g) {
512   DataType itype;
513   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Index", &itype));
514   if (itype != DT_INT32) {
515     return errors::Unimplemented(
516         "SliceGrad for int64 index are not supported.");
517   }
518 
519   // TODO(aselle): Shouldn't the int32 tensors return zeros of shape like
520   // dy_grad?
521   // I'm following slice's behavior for now.
522   *g = FDH::Define(
523       // Arg defs
524       {"shape: int32", "begin: int32", "end: int32", "stride: int32", "dy: T",
525        "grad: T"},
526       // Ret val defs
527       {"shape_grad: int32", "begin_grad: int32", "end_grad: int32",
528        "stride_grad: int32", "dy_grad: T"},
529       // Attr defs
530       {"T: type", "Index: {int32, int64}", "begin_mask: int", "end_mask: int",
531        "ellipsis_mask: int", "new_axis_mask: int", "shrink_axis_mask: int"},
532       {// Nodes
533        {{{"shape_grad"}, "ZerosLike", {"shape"}, {{"T", DT_INT32}}},
534         {{"begin_grad"}, "ZerosLike", {"begin"}, {{"T", DT_INT32}}},
535         {{"end_grad"}, "ZerosLike", {"end"}, {{"T", DT_INT32}}},
536         {{"stride_grad"}, "ZerosLike", {"stride"}, {{"T", DT_INT32}}},
537         {{"dy_grad"},
538          "StridedSlice",
539          {"grad", "begin", "end", "stride"},
540          {{"T", "$T"},
541           {"Index", "$Index"},
542           {"begin_mask", "$begin_mask"},
543           {"end_mask", "$end_mask"},
544           {"ellipsis_mask", "$ellipsis_mask"},
545           {"new_axis_mask", "$new_axis_mask"},
546           {"shrink_axis_mask", "$shrink_axis_mask"}}}}});
547 
548   VLOG(1) << "StridedSliceGrad " << DebugString(*g);
549   return Status::OK();
550 }
551 REGISTER_OP_GRADIENT("StridedSliceGrad", StridedSliceGradGrad);
552 
BroadcastToGrad(const AttrSlice & attrs,FunctionDef * g)553 Status BroadcastToGrad(const AttrSlice& attrs, FunctionDef* g) {
554   DataType itype;
555   TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &itype));
556   if (itype != DT_INT32) {
557     return errors::Unimplemented(
558         "BroadcastToGrad for int64 index are not supported.");
559   }
560   std::vector<FDH::Node> nodes = {
561       {{"sx"}, "Shape", {"x"}, {{"T", "$T"}}},
562       {{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "shape"}},
563       {{"sum_gx"}, "Sum", {"dy", "rx"}, {{"T", "$T"}}},
564       {{"dx"}, "Reshape", {"sum_gx", "sx"}, {{"T", "$T"}}},
565       {{"dshape"}, "ZerosLike", {"shape"}, {{"T", "$Tidx"}}}};
566   *g = FDH::Define(
567       // Arg defs
568       {"x: T", "shape: int32", "dy: T"},
569       // Ret val defs
570       {"dx: T", "dshape: Tidx"},
571       // Attr defs
572       {{"T: type"}, {"Tidx: {int32, int64}"}},
573       // Nodes
574       nodes);
575   return Status::OK();
576 }
577 REGISTER_OP_GRADIENT("BroadcastTo", BroadcastToGrad);
578 
579 }  // end namespace tensorflow
580