• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/op.h"
17 #include "tensorflow/core/framework/shape_inference.h"
18 
19 namespace tensorflow {
20 
21 using shape_inference::DimensionHandle;
22 using shape_inference::InferenceContext;
23 using shape_inference::ShapeHandle;
24 
ShapeOrHandleShape(InferenceContext * c,int input)25 static ShapeHandle ShapeOrHandleShape(InferenceContext* c, int input) {
26   auto* handle_data = c->input_handle_shapes_and_types(input);
27   if (handle_data != nullptr && !handle_data->empty() &&
28       (*handle_data)[0].dtype != DT_INVALID) {
29     return (*handle_data)[0].shape;
30   }
31   return c->input(input);
32 }
33 
34 // Handle the gradient and, if <sparse>, indices inputs.
35 // <s> is an input+output parameter, containing the current known input shape to
36 // the gradient.
HandleGradAndIndicesInputs(InferenceContext * c,bool sparse,int grad_idx,ShapeHandle * s)37 static Status HandleGradAndIndicesInputs(InferenceContext* c, bool sparse,
38                                          int grad_idx, ShapeHandle* s) {
39   ShapeHandle grad = ShapeOrHandleShape(c, grad_idx);
40   if (!sparse) {
41     TF_RETURN_IF_ERROR(c->Merge(*s, grad, s));
42     return Status::OK();
43   }
44   // Indices is a vector where indices.dim[0].rank == grad[0].rank.
45   ShapeHandle indices;
46   TF_RETURN_IF_ERROR(c->WithRank(c->input(grad_idx + 1), 1, &indices));
47   DimensionHandle unused;
48   TF_RETURN_IF_ERROR(c->Merge(c->Dim(indices, 0), c->Dim(grad, 0), &unused));
49 
50   // Trailing part of grad matches trailing part of *s.
51   ShapeHandle grad_unknown_first;
52   TF_RETURN_IF_ERROR(
53       c->ReplaceDim(grad, 0, c->UnknownDim(), &grad_unknown_first));
54   TF_RETURN_IF_ERROR(c->Merge(*s, grad_unknown_first, s));
55 
56   return Status::OK();
57 }
58 
ApplyGradientDescentShapeFn(InferenceContext * c)59 static Status ApplyGradientDescentShapeFn(InferenceContext* c) {
60   ShapeHandle unused;
61   ShapeHandle s = ShapeOrHandleShape(c, 0);                  // var
62   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));  // alpha
63   TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s));          // delta
64   if (c->num_outputs() > 0) {
65     c->set_output(0, s);
66   }
67   return Status::OK();
68 }
69 
70 REGISTER_OP("ApplyGradientDescent")
71     .Input("var: Ref(T)")
72     .Input("alpha: T")
73     .Input("delta: T")
74     .Output("out: Ref(T)")
75     .Attr("T: numbertype")
76     .Attr("use_locking: bool = false")
77     .SetShapeFn(ApplyGradientDescentShapeFn);
78 
79 REGISTER_OP("ResourceApplyGradientDescent")
80     .Input("var: resource")
81     .Input("alpha: T")
82     .Input("delta: T")
83     .Attr("T: numbertype")
84     .Attr("use_locking: bool = false")
85     .SetShapeFn(ApplyGradientDescentShapeFn);
86 
ApplyProximalGradientDescentShapeFn(InferenceContext * c,bool sparse)87 static Status ApplyProximalGradientDescentShapeFn(InferenceContext* c,
88                                                   bool sparse) {
89   ShapeHandle unused;
90   ShapeHandle s = ShapeOrHandleShape(c, 0);                  // var
91   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));  // alpha
92   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));  // l1
93   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // l2
94   TF_RETURN_IF_ERROR(
95       HandleGradAndIndicesInputs(c, sparse, 4 /* grad_idx */, &s));
96   if (c->num_outputs() > 0) {
97     c->set_output(0, s);
98   }
99   return Status::OK();
100 }
101 
102 REGISTER_OP("ApplyProximalGradientDescent")
103     .Input("var: Ref(T)")
104     .Input("alpha: T")
105     .Input("l1: T")
106     .Input("l2: T")
107     .Input("delta: T")
108     .Output("out: Ref(T)")
109     .Attr("T: numbertype")
110     .Attr("use_locking: bool = false")
__anoncd5e4d380102(InferenceContext* c) 111     .SetShapeFn([](InferenceContext* c) {
112       return ApplyProximalGradientDescentShapeFn(c, false /* sparse */);
113     });
114 
115 REGISTER_OP("SparseApplyProximalGradientDescent")
116     .Input("var: Ref(T)")
117     .Input("alpha: T")
118     .Input("l1: T")
119     .Input("l2: T")
120     .Input("grad: T")
121     .Input("indices: Tindices")
122     .Output("out: Ref(T)")
123     .Attr("T: numbertype")
124     .Attr("Tindices: {int32, int64}")
125     .Attr("use_locking: bool = false")
__anoncd5e4d380202(InferenceContext* c) 126     .SetShapeFn([](InferenceContext* c) {
127       return ApplyProximalGradientDescentShapeFn(c, true /* sparse */);
128     });
129 
130 REGISTER_OP("ResourceApplyProximalGradientDescent")
131     .Input("var: resource")
132     .Input("alpha: T")
133     .Input("l1: T")
134     .Input("l2: T")
135     .Input("delta: T")
136     .Attr("T: numbertype")
137     .Attr("use_locking: bool = false")
__anoncd5e4d380302(InferenceContext* c) 138     .SetShapeFn([](InferenceContext* c) {
139       return ApplyProximalGradientDescentShapeFn(c, false /* sparse */);
140     });
141 
142 REGISTER_OP("ResourceSparseApplyProximalGradientDescent")
143     .Input("var: resource")
144     .Input("alpha: T")
145     .Input("l1: T")
146     .Input("l2: T")
147     .Input("grad: T")
148     .Input("indices: Tindices")
149     .Attr("T: numbertype")
150     .Attr("Tindices: {int32, int64}")
151     .Attr("use_locking: bool = false")
__anoncd5e4d380402(InferenceContext* c) 152     .SetShapeFn([](InferenceContext* c) {
153       return ApplyProximalGradientDescentShapeFn(c, true /* sparse */);
154     });
155 
ApplyAdadeltaShapeFn(InferenceContext * c,bool sparse)156 static Status ApplyAdadeltaShapeFn(InferenceContext* c, bool sparse) {
157   ShapeHandle unused;
158   ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
159   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // accum
160   TF_RETURN_IF_ERROR(
161       c->Merge(s, ShapeOrHandleShape(c, 2), &s));            // accum update
162   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // lr
163   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));  // rho
164   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));  // epsilon
165   TF_RETURN_IF_ERROR(
166       HandleGradAndIndicesInputs(c, sparse, 6 /* grad_idx */, &s));
167   if (c->num_outputs() > 0) {
168     c->set_output(0, s);
169   }
170   return Status::OK();
171 }
172 
173 REGISTER_OP("ApplyAdadelta")
174     .Input("var: Ref(T)")
175     .Input("accum: Ref(T)")
176     .Input("accum_update: Ref(T)")
177     .Input("lr: T")
178     .Input("rho: T")
179     .Input("epsilon: T")
180     .Input("grad: T")
181     .Output("out: Ref(T)")
182     .Attr("T: numbertype")
183     .Attr("use_locking: bool = false")
__anoncd5e4d380502(InferenceContext* c) 184     .SetShapeFn([](InferenceContext* c) {
185       return ApplyAdadeltaShapeFn(c, false /* sparse */);
186     });
187 
188 REGISTER_OP("SparseApplyAdadelta")
189     .Input("var: Ref(T)")
190     .Input("accum: Ref(T)")
191     .Input("accum_update: Ref(T)")
192     .Input("lr: T")
193     .Input("rho: T")
194     .Input("epsilon: T")
195     .Input("grad: T")
196     .Input("indices: Tindices")
197     .Output("out: Ref(T)")
198     .Attr("T: numbertype")
199     .Attr("Tindices: {int32, int64}")
200     .Attr("use_locking: bool = false")
__anoncd5e4d380602(InferenceContext* c) 201     .SetShapeFn([](InferenceContext* c) {
202       return ApplyAdadeltaShapeFn(c, true /* sparse */);
203     });
204 
205 REGISTER_OP("ResourceApplyAdadelta")
206     .Input("var: resource")
207     .Input("accum: resource")
208     .Input("accum_update: resource")
209     .Input("lr: T")
210     .Input("rho: T")
211     .Input("epsilon: T")
212     .Input("grad: T")
213     .Attr("T: numbertype")
214     .Attr("use_locking: bool = false")
__anoncd5e4d380702(InferenceContext* c) 215     .SetShapeFn([](InferenceContext* c) {
216       return ApplyAdadeltaShapeFn(c, false /* sparse */);
217     });
218 
219 REGISTER_OP("ResourceSparseApplyAdadelta")
220     .Input("var: resource")
221     .Input("accum: resource")
222     .Input("accum_update: resource")
223     .Input("lr: T")
224     .Input("rho: T")
225     .Input("epsilon: T")
226     .Input("grad: T")
227     .Input("indices: Tindices")
228     .Attr("T: numbertype")
229     .Attr("Tindices: {int32, int64}")
230     .Attr("use_locking: bool = false")
__anoncd5e4d380802(InferenceContext* c) 231     .SetShapeFn([](InferenceContext* c) {
232       return ApplyAdadeltaShapeFn(c, true /* sparse */);
233     });
234 
ApplyAdagradShapeFn(InferenceContext * c,bool sparse)235 static Status ApplyAdagradShapeFn(InferenceContext* c, bool sparse) {
236   ShapeHandle unused;
237   ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
238   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // accum
239   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));       // lr
240   TF_RETURN_IF_ERROR(
241       HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
242   if (c->num_outputs() > 0) {
243     c->set_output(0, s);
244   }
245   return Status::OK();
246 }
247 
248 REGISTER_OP("ApplyAdagrad")
249     .Input("var: Ref(T)")
250     .Input("accum: Ref(T)")
251     .Input("lr: T")
252     .Input("grad: T")
253     .Output("out: Ref(T)")
254     .Attr("T: numbertype")
255     .Attr("use_locking: bool = false")
256     .Attr("update_slots: bool = true")
__anoncd5e4d380902(InferenceContext* c) 257     .SetShapeFn([](InferenceContext* c) {
258       return ApplyAdagradShapeFn(c, false /* sparse */);
259     });
260 
261 REGISTER_OP("ResourceApplyAdagrad")
262     .Input("var: resource")
263     .Input("accum: resource")
264     .Input("lr: T")
265     .Input("grad: T")
266     .Attr("T: numbertype")
267     .Attr("use_locking: bool = false")
268     .Attr("update_slots: bool = true")
__anoncd5e4d380a02(InferenceContext* c) 269     .SetShapeFn([](InferenceContext* c) {
270       return ApplyAdagradShapeFn(c, false /* sparse */);
271     });
272 
ApplyProximalAdagradShapeFn(InferenceContext * c,bool sparse)273 static Status ApplyProximalAdagradShapeFn(InferenceContext* c, bool sparse) {
274   ShapeHandle unused;
275   ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
276   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // accum
277   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));       // lr
278   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));       // l1
279   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));       // l2
280   TF_RETURN_IF_ERROR(
281       HandleGradAndIndicesInputs(c, sparse, 5 /* grad_idx */, &s));
282   if (c->num_outputs() > 0) {
283     c->set_output(0, s);
284   }
285   return Status::OK();
286 }
287 
288 REGISTER_OP("ApplyProximalAdagrad")
289     .Input("var: Ref(T)")
290     .Input("accum: Ref(T)")
291     .Input("lr: T")
292     .Input("l1: T")
293     .Input("l2: T")
294     .Input("grad: T")
295     .Output("out: Ref(T)")
296     .Attr("T: numbertype")
297     .Attr("use_locking: bool = false")
__anoncd5e4d380b02(InferenceContext* c) 298     .SetShapeFn([](InferenceContext* c) {
299       return ApplyProximalAdagradShapeFn(c, false /* sparse */);
300     });
301 
302 REGISTER_OP("ResourceApplyProximalAdagrad")
303     .Input("var: resource")
304     .Input("accum: resource")
305     .Input("lr: T")
306     .Input("l1: T")
307     .Input("l2: T")
308     .Input("grad: T")
309     .Attr("T: numbertype")
310     .Attr("use_locking: bool = false")
__anoncd5e4d380c02(InferenceContext* c) 311     .SetShapeFn([](InferenceContext* c) {
312       return ApplyProximalAdagradShapeFn(c, false /* sparse */);
313     });
314 
315 REGISTER_OP("SparseApplyAdagrad")
316     .Input("var: Ref(T)")
317     .Input("accum: Ref(T)")
318     .Input("lr: T")
319     .Input("grad: T")
320     .Input("indices: Tindices")
321     .Output("out: Ref(T)")
322     .Attr("T: numbertype")
323     .Attr("Tindices: {int32, int64}")
324     .Attr("use_locking: bool = false")
325     .Attr("update_slots: bool = true")
__anoncd5e4d380d02(InferenceContext* c) 326     .SetShapeFn([](InferenceContext* c) {
327       return ApplyAdagradShapeFn(c, true /* sparse */);
328     });
329 
330 REGISTER_OP("ResourceSparseApplyAdagrad")
331     .Input("var: resource")
332     .Input("accum: resource")
333     .Input("lr: T")
334     .Input("grad: T")
335     .Input("indices: Tindices")
336     .Attr("T: numbertype")
337     .Attr("Tindices: {int32, int64}")
338     .Attr("use_locking: bool = false")
339     .Attr("update_slots: bool = true")
__anoncd5e4d380e02(InferenceContext* c) 340     .SetShapeFn([](InferenceContext* c) {
341       return ApplyAdagradShapeFn(c, true /* sparse */);
342     });
343 
ApplyAdagradDAShapeFn(InferenceContext * c,bool sparse)344 static Status ApplyAdagradDAShapeFn(InferenceContext* c, bool sparse) {
345   ShapeHandle unused;
346   ShapeHandle s = ShapeOrHandleShape(c, 0);  // var
347   TF_RETURN_IF_ERROR(
348       c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // grad_accumulator
349   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2),
350                               &s));  // gradient_squared_accumulator
351   TF_RETURN_IF_ERROR(
352       HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
353   int idx = sparse ? 5 : 4;
354   TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // lr
355   TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // l1
356   TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // l2
357   TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // global step
358   if (c->num_outputs() > 0) {
359     c->set_output(0, s);
360   }
361   return Status::OK();
362 }
363 
364 REGISTER_OP("ApplyAdagradDA")
365     .Input("var: Ref(T)")
366     .Input("gradient_accumulator: Ref(T)")
367     .Input("gradient_squared_accumulator: Ref(T)")
368     .Input("grad: T")
369     .Input("lr: T")
370     .Input("l1: T")
371     .Input("l2: T")
372     .Input("global_step: int64")
373     .Output("out: Ref(T)")
374     .Attr("T: numbertype")
375     .Attr("use_locking: bool = false")
__anoncd5e4d380f02(InferenceContext* c) 376     .SetShapeFn([](InferenceContext* c) {
377       return ApplyAdagradDAShapeFn(c, false /* sparse */);
378     });
379 
380 REGISTER_OP("SparseApplyAdagradDA")
381     .Input("var: Ref(T)")
382     .Input("gradient_accumulator: Ref(T)")
383     .Input("gradient_squared_accumulator: Ref(T)")
384     .Input("grad: T")
385     .Input("indices: Tindices")
386     .Input("lr: T")
387     .Input("l1: T")
388     .Input("l2: T")
389     .Input("global_step: int64")
390     .Output("out: Ref(T)")
391     .Attr("T: numbertype")
392     .Attr("Tindices: {int32, int64}")
393     .Attr("use_locking: bool = false")
__anoncd5e4d381002(InferenceContext* c) 394     .SetShapeFn([](InferenceContext* c) {
395       return ApplyAdagradDAShapeFn(c, true /* sparse */);
396     });
397 
398 REGISTER_OP("SparseApplyProximalAdagrad")
399     .Input("var: Ref(T)")
400     .Input("accum: Ref(T)")
401     .Input("lr: T")
402     .Input("l1: T")
403     .Input("l2: T")
404     .Input("grad: T")
405     .Input("indices: Tindices")
406     .Output("out: Ref(T)")
407     .Attr("T: numbertype")
408     .Attr("Tindices: {int32, int64}")
409     .Attr("use_locking: bool = false")
__anoncd5e4d381102(InferenceContext* c) 410     .SetShapeFn([](InferenceContext* c) {
411       return ApplyProximalAdagradShapeFn(c, true /* sparse */);
412     });
413 
414 REGISTER_OP("ResourceApplyAdagradDA")
415     .Input("var: resource")
416     .Input("gradient_accumulator: resource")
417     .Input("gradient_squared_accumulator: resource")
418     .Input("grad: T")
419     .Input("lr: T")
420     .Input("l1: T")
421     .Input("l2: T")
422     .Input("global_step: int64")
423     .Attr("T: numbertype")
424     .Attr("use_locking: bool = false")
__anoncd5e4d381202(InferenceContext* c) 425     .SetShapeFn([](InferenceContext* c) {
426       return ApplyAdagradDAShapeFn(c, false /* sparse */);
427     });
428 
429 REGISTER_OP("ResourceSparseApplyAdagradDA")
430     .Input("var: resource")
431     .Input("gradient_accumulator: resource")
432     .Input("gradient_squared_accumulator: resource")
433     .Input("grad: T")
434     .Input("indices: Tindices")
435     .Input("lr: T")
436     .Input("l1: T")
437     .Input("l2: T")
438     .Input("global_step: int64")
439     .Attr("T: numbertype")
440     .Attr("Tindices: {int32, int64}")
441     .Attr("use_locking: bool = false")
__anoncd5e4d381302(InferenceContext* c) 442     .SetShapeFn([](InferenceContext* c) {
443       return ApplyAdagradDAShapeFn(c, true /* sparse */);
444     });
445 
446 REGISTER_OP("ResourceSparseApplyProximalAdagrad")
447     .Input("var: resource")
448     .Input("accum: resource")
449     .Input("lr: T")
450     .Input("l1: T")
451     .Input("l2: T")
452     .Input("grad: T")
453     .Input("indices: Tindices")
454     .Attr("T: numbertype")
455     .Attr("Tindices: {int32, int64}")
456     .Attr("use_locking: bool = false")
__anoncd5e4d381402(InferenceContext* c) 457     .SetShapeFn([](InferenceContext* c) {
458       return ApplyProximalAdagradShapeFn(c, true /* sparse */);
459     });
460 
ApplyFtrlShapeFn(InferenceContext * c,bool sparse)461 static Status ApplyFtrlShapeFn(InferenceContext* c, bool sparse) {
462   ShapeHandle unused;
463   ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
464   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // accum
465   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s));  // linear
466   TF_RETURN_IF_ERROR(
467       HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
468   int idx = sparse ? 5 : 4;
469   TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // lr
470   TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // l1
471   TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // l2
472   TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // lr_power
473   if (c->num_outputs() > 0) {
474     c->set_output(0, s);
475   }
476   return Status::OK();
477 }
478 
479 REGISTER_OP("ApplyFtrl")
480     .Input("var: Ref(T)")
481     .Input("accum: Ref(T)")
482     .Input("linear: Ref(T)")
483     .Input("grad: T")
484     .Input("lr: T")
485     .Input("l1: T")
486     .Input("l2: T")
487     .Input("lr_power: T")
488     .Output("out: Ref(T)")
489     .Attr("T: numbertype")
490     .Attr("use_locking: bool = false")
__anoncd5e4d381502(InferenceContext* c) 491     .SetShapeFn([](InferenceContext* c) {
492       return ApplyFtrlShapeFn(c, false /* sparse */);
493     });
494 
495 REGISTER_OP("SparseApplyFtrl")
496     .Input("var: Ref(T)")
497     .Input("accum: Ref(T)")
498     .Input("linear: Ref(T)")
499     .Input("grad: T")
500     .Input("indices: Tindices")
501     .Input("lr: T")
502     .Input("l1: T")
503     .Input("l2: T")
504     .Input("lr_power: T")
505     .Output("out: Ref(T)")
506     .Attr("T: numbertype")
507     .Attr("Tindices: {int32, int64}")
508     .Attr("use_locking: bool = false")
__anoncd5e4d381602(InferenceContext* c) 509     .SetShapeFn([](InferenceContext* c) {
510       return ApplyFtrlShapeFn(c, true /* sparse */);
511     });
512 
513 REGISTER_OP("ResourceApplyFtrl")
514     .Input("var: resource")
515     .Input("accum: resource")
516     .Input("linear: resource")
517     .Input("grad: T")
518     .Input("lr: T")
519     .Input("l1: T")
520     .Input("l2: T")
521     .Input("lr_power: T")
522     .Attr("T: numbertype")
523     .Attr("use_locking: bool = false")
__anoncd5e4d381702(InferenceContext* c) 524     .SetShapeFn([](InferenceContext* c) {
525       return ApplyFtrlShapeFn(c, false /* sparse */);
526     });
527 
528 REGISTER_OP("ResourceSparseApplyFtrl")
529     .Input("var: resource")
530     .Input("accum: resource")
531     .Input("linear: resource")
532     .Input("grad: T")
533     .Input("indices: Tindices")
534     .Input("lr: T")
535     .Input("l1: T")
536     .Input("l2: T")
537     .Input("lr_power: T")
538     .Attr("T: numbertype")
539     .Attr("Tindices: {int32, int64}")
540     .Attr("use_locking: bool = false")
__anoncd5e4d381802(InferenceContext* c) 541     .SetShapeFn([](InferenceContext* c) {
542       return ApplyFtrlShapeFn(c, true /* sparse */);
543     });
544 
545 REGISTER_OP("ApplyFtrlV2")
546     .Input("var: Ref(T)")
547     .Input("accum: Ref(T)")
548     .Input("linear: Ref(T)")
549     .Input("grad: T")
550     .Input("lr: T")
551     .Input("l1: T")
552     .Input("l2: T")
553     .Input("l2_shrinkage: T")
554     .Input("lr_power: T")
555     .Output("out: Ref(T)")
556     .Attr("T: numbertype")
557     .Attr("use_locking: bool = false")
__anoncd5e4d381902(InferenceContext* c) 558     .SetShapeFn([](InferenceContext* c) {
559       return ApplyFtrlShapeFn(c, false /* sparse */);
560     });
561 
562 REGISTER_OP("SparseApplyFtrlV2")
563     .Input("var: Ref(T)")
564     .Input("accum: Ref(T)")
565     .Input("linear: Ref(T)")
566     .Input("grad: T")
567     .Input("indices: Tindices")
568     .Input("lr: T")
569     .Input("l1: T")
570     .Input("l2: T")
571     .Input("l2_shrinkage: T")
572     .Input("lr_power: T")
573     .Output("out: Ref(T)")
574     .Attr("T: numbertype")
575     .Attr("Tindices: {int32, int64}")
576     .Attr("use_locking: bool = false")
__anoncd5e4d381a02(InferenceContext* c) 577     .SetShapeFn([](InferenceContext* c) {
578       return ApplyFtrlShapeFn(c, true /* sparse */);
579     });
580 
581 REGISTER_OP("ResourceApplyFtrlV2")
582     .Input("var: resource")
583     .Input("accum: resource")
584     .Input("linear: resource")
585     .Input("grad: T")
586     .Input("lr: T")
587     .Input("l1: T")
588     .Input("l2: T")
589     .Input("l2_shrinkage: T")
590     .Input("lr_power: T")
591     .Attr("T: numbertype")
592     .Attr("use_locking: bool = false")
__anoncd5e4d381b02(InferenceContext* c) 593     .SetShapeFn([](InferenceContext* c) {
594       return ApplyFtrlShapeFn(c, false /* sparse */);
595     });
596 
597 REGISTER_OP("ResourceSparseApplyFtrlV2")
598     .Input("var: resource")
599     .Input("accum: resource")
600     .Input("linear: resource")
601     .Input("grad: T")
602     .Input("indices: Tindices")
603     .Input("lr: T")
604     .Input("l1: T")
605     .Input("l2: T")
606     .Input("l2_shrinkage: T")
607     .Input("lr_power: T")
608     .Attr("T: numbertype")
609     .Attr("Tindices: {int32, int64}")
610     .Attr("use_locking: bool = false")
__anoncd5e4d381c02(InferenceContext* c) 611     .SetShapeFn([](InferenceContext* c) {
612       return ApplyFtrlShapeFn(c, true /* sparse */);
613     });
614 
ApplyMomentumShapeFn(InferenceContext * c,bool sparse)615 static Status ApplyMomentumShapeFn(InferenceContext* c, bool sparse) {
616   ShapeHandle unused;
617   ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
618   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // accum
619   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));       // lr
620   TF_RETURN_IF_ERROR(
621       HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
622   int idx = sparse ? 5 : 4;
623   TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // momentum
624   if (c->num_outputs() > 0) {
625     c->set_output(0, s);
626   }
627   return Status::OK();
628 }
629 
630 REGISTER_OP("ApplyMomentum")
631     .Input("var: Ref(T)")
632     .Input("accum: Ref(T)")
633     .Input("lr: T")
634     .Input("grad: T")
635     .Input("momentum: T")
636     .Output("out: Ref(T)")
637     .Attr("T: numbertype")
638     .Attr("use_locking: bool = false")
639     .Attr("use_nesterov: bool = false")
__anoncd5e4d381d02(InferenceContext* c) 640     .SetShapeFn([](InferenceContext* c) {
641       return ApplyMomentumShapeFn(c, false /* sparse */);
642     });
643 
644 REGISTER_OP("SparseApplyMomentum")
645     .Input("var: Ref(T)")
646     .Input("accum: Ref(T)")
647     .Input("lr: T")
648     .Input("grad: T")
649     .Input("indices: Tindices")
650     .Input("momentum: T")
651     .Output("out: Ref(T)")
652     .Attr("T: numbertype")
653     .Attr("Tindices: {int32, int64}")
654     .Attr("use_locking: bool = false")
655     .Attr("use_nesterov: bool = false")
__anoncd5e4d381e02(InferenceContext* c) 656     .SetShapeFn([](InferenceContext* c) {
657       return ApplyMomentumShapeFn(c, true /* sparse */);
658     });
659 
660 REGISTER_OP("ResourceApplyMomentum")
661     .Input("var: resource")
662     .Input("accum: resource")
663     .Input("lr: T")
664     .Input("grad: T")
665     .Input("momentum: T")
666     .Attr("T: numbertype")
667     .Attr("use_locking: bool = false")
668     .Attr("use_nesterov: bool = false")
__anoncd5e4d381f02(InferenceContext* c) 669     .SetShapeFn([](InferenceContext* c) {
670       return ApplyMomentumShapeFn(c, false /* sparse */);
671     });
672 
673 REGISTER_OP("ResourceSparseApplyMomentum")
674     .Input("var: resource")
675     .Input("accum: resource")
676     .Input("lr: T")
677     .Input("grad: T")
678     .Input("indices: Tindices")
679     .Input("momentum: T")
680     .Attr("T: numbertype")
681     .Attr("Tindices: {int32, int64}")
682     .Attr("use_locking: bool = false")
683     .Attr("use_nesterov: bool = false")
__anoncd5e4d382002(InferenceContext* c) 684     .SetShapeFn([](InferenceContext* c) {
685       return ApplyMomentumShapeFn(c, true /* sparse */);
686     });
687 
688 REGISTER_OP("ResourceApplyKerasMomentum")
689     .Input("var: resource")
690     .Input("accum: resource")
691     .Input("lr: T")
692     .Input("grad: T")
693     .Input("momentum: T")
694     .Attr("T: numbertype")
695     .Attr("use_locking: bool = false")
696     .Attr("use_nesterov: bool = false")
__anoncd5e4d382102(InferenceContext* c) 697     .SetShapeFn([](InferenceContext* c) {
698       return ApplyMomentumShapeFn(c, false /* sparse */);
699     });
700 
701 REGISTER_OP("ResourceSparseApplyKerasMomentum")
702     .Input("var: resource")
703     .Input("accum: resource")
704     .Input("lr: T")
705     .Input("grad: T")
706     .Input("indices: Tindices")
707     .Input("momentum: T")
708     .Attr("T: numbertype")
709     .Attr("Tindices: {int32, int64}")
710     .Attr("use_locking: bool = false")
711     .Attr("use_nesterov: bool = false")
__anoncd5e4d382202(InferenceContext* c) 712     .SetShapeFn([](InferenceContext* c) {
713       return ApplyMomentumShapeFn(c, true /* sparse */);
714     });
715 
ApplyAdamShapeFn(InferenceContext * c,bool sparse)716 static Status ApplyAdamShapeFn(InferenceContext* c, bool sparse) {
717   ShapeHandle unused;
718   ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
719   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // m
720   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s));  // v
721   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));       // beta1_power
722   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));       // beta2_power
723   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));       // lr
724   TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));       // beta1
725   TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));       // beta2
726   TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));       // epsilon
727   TF_RETURN_IF_ERROR(
728       HandleGradAndIndicesInputs(c, sparse, 9 /* grad_idx */, &s));
729   if (c->num_outputs() > 0) {
730     c->set_output(0, s);
731   }
732   return Status::OK();
733 }
734 
735 REGISTER_OP("ApplyAdam")
736     .Input("var: Ref(T)")
737     .Input("m: Ref(T)")
738     .Input("v: Ref(T)")
739     .Input("beta1_power: T")
740     .Input("beta2_power: T")
741     .Input("lr: T")
742     .Input("beta1: T")
743     .Input("beta2: T")
744     .Input("epsilon: T")
745     .Input("grad: T")
746     .Output("out: Ref(T)")
747     .Attr("T: numbertype")
748     .Attr("use_locking: bool = false")
749     .Attr("use_nesterov: bool = false")
__anoncd5e4d382302(InferenceContext* c) 750     .SetShapeFn([](InferenceContext* c) {
751       return ApplyAdamShapeFn(c, false /* sparse */);
752     });
753 
754 REGISTER_OP("ResourceApplyAdam")
755     .Input("var: resource")
756     .Input("m: resource")
757     .Input("v: resource")
758     .Input("beta1_power: T")
759     .Input("beta2_power: T")
760     .Input("lr: T")
761     .Input("beta1: T")
762     .Input("beta2: T")
763     .Input("epsilon: T")
764     .Input("grad: T")
765     .Attr("T: numbertype")
766     .Attr("use_locking: bool = false")
767     .Attr("use_nesterov: bool = false")
__anoncd5e4d382402(InferenceContext* c) 768     .SetShapeFn([](InferenceContext* c) {
769       return ApplyAdamShapeFn(c, false /* sparse */);
770     });
771 
ApplyAdamWithAmsgradShapeFn(InferenceContext * c,bool sparse)772 static Status ApplyAdamWithAmsgradShapeFn(InferenceContext* c, bool sparse) {
773   ShapeHandle unused;
774   ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
775   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // m
776   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s));  // v
777   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 3), &s));  // vhat
778   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));       // beta1_power
779   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));       // beta2_power
780   TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));       // lr
781   TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));       // beta1
782   TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));       // beta2
783   TF_RETURN_IF_ERROR(c->WithRank(c->input(9), 0, &unused));       // epsilon
784   TF_RETURN_IF_ERROR(
785       HandleGradAndIndicesInputs(c, sparse, 10 /* grad_idx */, &s));
786   if (c->num_outputs() > 0) {
787     c->set_output(0, s);
788   }
789   return Status::OK();
790 }
791 
792 REGISTER_OP("ResourceApplyAdamWithAmsgrad")
793     .Input("var: resource")
794     .Input("m: resource")
795     .Input("v: resource")
796     .Input("vhat: resource")
797     .Input("beta1_power: T")
798     .Input("beta2_power: T")
799     .Input("lr: T")
800     .Input("beta1: T")
801     .Input("beta2: T")
802     .Input("epsilon: T")
803     .Input("grad: T")
804     .Attr("T: numbertype")
805     .Attr("use_locking: bool = false")
__anoncd5e4d382502(InferenceContext* c) 806     .SetShapeFn([](InferenceContext* c) {
807       return ApplyAdamWithAmsgradShapeFn(c, false /* sparse */);
808     });
809 
ApplyAdaMaxShapeFn(InferenceContext * c,bool sparse)810 static Status ApplyAdaMaxShapeFn(InferenceContext* c, bool sparse) {
811   ShapeHandle unused;
812   ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
813   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // m
814   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s));  // v
815   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));       // beta1_power
816   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));       // lr
817   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));       // beta1
818   TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));       // beta2
819   TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));       // epsilon
820   TF_RETURN_IF_ERROR(
821       HandleGradAndIndicesInputs(c, sparse, 8 /* grad_idx */, &s));
822   if (c->num_outputs() > 0) {
823     c->set_output(0, s);
824   }
825   return Status::OK();
826 }
827 
828 REGISTER_OP("ApplyAdaMax")
829     .Input("var: Ref(T)")
830     .Input("m: Ref(T)")
831     .Input("v: Ref(T)")
832     .Input("beta1_power: T")
833     .Input("lr: T")
834     .Input("beta1: T")
835     .Input("beta2: T")
836     .Input("epsilon: T")
837     .Input("grad: T")
838     .Output("out: Ref(T)")
839     .Attr("T: numbertype")
840     .Attr("use_locking: bool = false")
__anoncd5e4d382602(InferenceContext* c) 841     .SetShapeFn([](InferenceContext* c) {
842       return ApplyAdaMaxShapeFn(c, false /* sparse */);
843     });
844 
845 REGISTER_OP("ResourceApplyAdaMax")
846     .Input("var: resource")
847     .Input("m: resource")
848     .Input("v: resource")
849     .Input("beta1_power: T")
850     .Input("lr: T")
851     .Input("beta1: T")
852     .Input("beta2: T")
853     .Input("epsilon: T")
854     .Input("grad: T")
855     .Attr("T: numbertype")
856     .Attr("use_locking: bool = false")
__anoncd5e4d382702(InferenceContext* c) 857     .SetShapeFn([](InferenceContext* c) {
858       return ApplyAdaMaxShapeFn(c, false /* sparse */);
859     });
860 
ApplyRMSPropShapeFn(InferenceContext * c,bool sparse)861 static Status ApplyRMSPropShapeFn(InferenceContext* c, bool sparse) {
862   ShapeHandle unused;
863   ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
864   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // ms
865   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s));  // mom
866   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));       // lr
867   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));       // rho
868   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));       // momentum
869   TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));       // epsilon
870   TF_RETURN_IF_ERROR(
871       HandleGradAndIndicesInputs(c, sparse, 7 /* grad_idx */, &s));
872   if (c->num_outputs() > 0) {
873     c->set_output(0, s);
874   }
875   return Status::OK();
876 }
877 
ApplyCenteredRMSPropShapeFn(InferenceContext * c,bool sparse)878 static Status ApplyCenteredRMSPropShapeFn(InferenceContext* c, bool sparse) {
879   ShapeHandle unused;
880   ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
881   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // ms
882   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s));  // mg
883   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 3), &s));  // mom
884   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));       // lr
885   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));       // rho
886   TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));       // momentum
887   TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));       // epsilon
888   TF_RETURN_IF_ERROR(
889       HandleGradAndIndicesInputs(c, sparse, 8 /* grad_idx */, &s));
890   if (c->num_outputs() > 0) {
891     c->set_output(0, s);
892   }
893   return Status::OK();
894 }
895 
896 REGISTER_OP("ApplyRMSProp")
897     .Input("var: Ref(T)")
898     .Input("ms: Ref(T)")
899     .Input("mom: Ref(T)")
900     .Input("lr: T")
901     .Input("rho: T")
902     .Input("momentum: T")
903     .Input("epsilon: T")
904     .Input("grad: T")
905     .Output("out: Ref(T)")
906     .Attr("T: numbertype")
907     .Attr("use_locking: bool = false")
__anoncd5e4d382802(InferenceContext* c) 908     .SetShapeFn([](InferenceContext* c) {
909       return ApplyRMSPropShapeFn(c, false /* sparse */);
910     });
911 
912 REGISTER_OP("ApplyCenteredRMSProp")
913     .Input("var: Ref(T)")
914     .Input("mg: Ref(T)")
915     .Input("ms: Ref(T)")
916     .Input("mom: Ref(T)")
917     .Input("lr: T")
918     .Input("rho: T")
919     .Input("momentum: T")
920     .Input("epsilon: T")
921     .Input("grad: T")
922     .Output("out: Ref(T)")
923     .Attr("T: numbertype")
924     .Attr("use_locking: bool = false")
__anoncd5e4d382902(InferenceContext* c) 925     .SetShapeFn([](InferenceContext* c) {
926       return ApplyCenteredRMSPropShapeFn(c, false /* sparse */);
927     });
928 
929 REGISTER_OP("SparseApplyRMSProp")
930     .Input("var: Ref(T)")
931     .Input("ms: Ref(T)")
932     .Input("mom: Ref(T)")
933     .Input("lr: T")
934     .Input("rho: T")
935     .Input("momentum: T")
936     .Input("epsilon: T")
937     .Input("grad: T")
938     .Input("indices: Tindices")
939     .Output("out: Ref(T)")
940     .Attr("T: numbertype")
941     .Attr("Tindices: {int32, int64}")
942     .Attr("use_locking: bool = false")
__anoncd5e4d382a02(InferenceContext* c) 943     .SetShapeFn([](InferenceContext* c) {
944       return ApplyRMSPropShapeFn(c, true /* sparse */);
945     });
946 
947 REGISTER_OP("SparseApplyCenteredRMSProp")
948     .Input("var: Ref(T)")
949     .Input("mg: Ref(T)")
950     .Input("ms: Ref(T)")
951     .Input("mom: Ref(T)")
952     .Input("lr: T")
953     .Input("rho: T")
954     .Input("momentum: T")
955     .Input("epsilon: T")
956     .Input("grad: T")
957     .Input("indices: Tindices")
958     .Output("out: Ref(T)")
959     .Attr("T: numbertype")
960     .Attr("Tindices: {int32, int64}")
961     .Attr("use_locking: bool = false")
__anoncd5e4d382b02(InferenceContext* c) 962     .SetShapeFn([](InferenceContext* c) {
963       return ApplyCenteredRMSPropShapeFn(c, true /* sparse */);
964     });
965 
966 REGISTER_OP("ResourceApplyRMSProp")
967     .Input("var: resource")
968     .Input("ms: resource")
969     .Input("mom: resource")
970     .Input("lr: T")
971     .Input("rho: T")
972     .Input("momentum: T")
973     .Input("epsilon: T")
974     .Input("grad: T")
975     .Attr("T: numbertype")
976     .Attr("use_locking: bool = false")
__anoncd5e4d382c02(InferenceContext* c) 977     .SetShapeFn([](InferenceContext* c) {
978       return ApplyRMSPropShapeFn(c, false /* sparse */);
979     });
980 
981 REGISTER_OP("ResourceApplyCenteredRMSProp")
982     .Input("var: resource")
983     .Input("mg: resource")
984     .Input("ms: resource")
985     .Input("mom: resource")
986     .Input("lr: T")
987     .Input("rho: T")
988     .Input("momentum: T")
989     .Input("epsilon: T")
990     .Input("grad: T")
991     .Attr("T: numbertype")
992     .Attr("use_locking: bool = false")
__anoncd5e4d382d02(InferenceContext* c) 993     .SetShapeFn([](InferenceContext* c) {
994       return ApplyCenteredRMSPropShapeFn(c, false /* sparse */);
995     });
996 
997 REGISTER_OP("ResourceSparseApplyRMSProp")
998     .Input("var: resource")
999     .Input("ms: resource")
1000     .Input("mom: resource")
1001     .Input("lr: T")
1002     .Input("rho: T")
1003     .Input("momentum: T")
1004     .Input("epsilon: T")
1005     .Input("grad: T")
1006     .Input("indices: Tindices")
1007     .Attr("T: numbertype")
1008     .Attr("Tindices: {int32, int64}")
1009     .Attr("use_locking: bool = false")
__anoncd5e4d382e02(InferenceContext* c) 1010     .SetShapeFn([](InferenceContext* c) {
1011       return ApplyRMSPropShapeFn(c, true /* sparse */);
1012     });
1013 
1014 REGISTER_OP("ResourceSparseApplyCenteredRMSProp")
1015     .Input("var: resource")
1016     .Input("mg: resource")
1017     .Input("ms: resource")
1018     .Input("mom: resource")
1019     .Input("lr: T")
1020     .Input("rho: T")
1021     .Input("momentum: T")
1022     .Input("epsilon: T")
1023     .Input("grad: T")
1024     .Input("indices: Tindices")
1025     .Attr("T: numbertype")
1026     .Attr("Tindices: {int32, int64}")
1027     .Attr("use_locking: bool = false")
__anoncd5e4d382f02(InferenceContext* c) 1028     .SetShapeFn([](InferenceContext* c) {
1029       return ApplyCenteredRMSPropShapeFn(c, true /* sparse */);
1030     });
1031 
ApplyAddSignShapeFn(InferenceContext * c,bool sparse)1032 static Status ApplyAddSignShapeFn(InferenceContext* c, bool sparse) {
1033   ShapeHandle unused;
1034   ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
1035   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // m
1036   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));       // lr
1037   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));       // alpha
1038   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));       // sign_decay
1039   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));       // beta
1040   TF_RETURN_IF_ERROR(
1041       HandleGradAndIndicesInputs(c, sparse, 6 /* grad_idx */, &s));
1042   if (c->num_outputs() > 0) {
1043     c->set_output(0, s);
1044   }
1045   return Status::OK();
1046 }
1047 
1048 REGISTER_OP("ApplyAddSign")
1049     .Input("var: Ref(T)")
1050     .Input("m: Ref(T)")
1051     .Input("lr: T")
1052     .Input("alpha: T")
1053     .Input("sign_decay: T")
1054     .Input("beta: T")
1055     .Input("grad: T")
1056     .Output("out: Ref(T)")
1057     .Attr("T: numbertype")
1058     .Attr("use_locking: bool = false")
__anoncd5e4d383002(InferenceContext* c) 1059     .SetShapeFn([](InferenceContext* c) {
1060       return ApplyAddSignShapeFn(c, /*sparse=*/false);
1061     });
1062 
1063 REGISTER_OP("ResourceApplyAddSign")
1064     .Input("var: resource")
1065     .Input("m: resource")
1066     .Input("lr: T")
1067     .Input("alpha: T")
1068     .Input("sign_decay: T")
1069     .Input("beta: T")
1070     .Input("grad: T")
1071     .Attr("T: numbertype")
1072     .Attr("use_locking: bool = false")
__anoncd5e4d383102(InferenceContext* c) 1073     .SetShapeFn([](InferenceContext* c) {
1074       return ApplyAddSignShapeFn(c, /*sparse=*/false);
1075     });
1076 
ApplyPowerSignShapeFn(InferenceContext * c,bool sparse)1077 static Status ApplyPowerSignShapeFn(InferenceContext* c, bool sparse) {
1078   ShapeHandle unused;
1079   ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
1080   TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // m
1081   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));       // lr
1082   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));       // logbase
1083   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));       // sign_delay
1084   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));       // beta
1085   TF_RETURN_IF_ERROR(
1086       HandleGradAndIndicesInputs(c, sparse, 6 /* grad_idx */, &s));
1087   if (c->num_outputs() > 0) {
1088     c->set_output(0, s);
1089   }
1090   return Status::OK();
1091 }
1092 
1093 REGISTER_OP("ApplyPowerSign")
1094     .Input("var: Ref(T)")
1095     .Input("m: Ref(T)")
1096     .Input("lr: T")
1097     .Input("logbase: T")
1098     .Input("sign_decay: T")
1099     .Input("beta: T")
1100     .Input("grad: T")
1101     .Output("out: Ref(T)")
1102     .Attr("T: numbertype")
1103     .Attr("use_locking: bool = false")
__anoncd5e4d383202(InferenceContext* c) 1104     .SetShapeFn([](InferenceContext* c) {
1105       return ApplyPowerSignShapeFn(c, /*sparse=*/false);
1106     });
1107 
1108 REGISTER_OP("ResourceApplyPowerSign")
1109     .Input("var: resource")
1110     .Input("m: resource")
1111     .Input("lr: T")
1112     .Input("logbase: T")
1113     .Input("sign_decay: T")
1114     .Input("beta: T")
1115     .Input("grad: T")
1116     .Attr("T: numbertype")
1117     .Attr("use_locking: bool = false")
__anoncd5e4d383302(InferenceContext* c) 1118     .SetShapeFn([](InferenceContext* c) {
1119       return ApplyPowerSignShapeFn(c, /*sparse=*/false);
1120     });
1121 
1122 }  // namespace tensorflow
1123