/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

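// Returns the shape recorded in the handle metadata of a ref/resource input
// when it is present and valid; otherwise falls back to the shape of the
// input tensor itself.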
static ShapeHandle ShapeOrHandleShape(InferenceContext* c, int input) {
  auto* handle_data = c->input_handle_shapes_and_types(input);
  if (handle_data != nullptr && !handle_data->empty() &&
      (*handle_data)[0].dtype != DT_INVALID) {
    return (*handle_data)[0].shape;
  }
  return c->input(input);
}

// Handles the gradient and, if <sparse>, indices inputs.
// <s> is an input+output parameter, containing the currently known shape of
// the gradient.
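// As an illustrative shape sketch: with var of shape [10, 4], the dense case
// requires grad to merge with [10, 4] directly, while the sparse case accepts
// grad: [N, 4] plus indices: [N], merging only grad's trailing dimensions.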
static Status HandleGradAndIndicesInputs(InferenceContext* c, bool sparse,
                                         int grad_idx, ShapeHandle* s) {
  ShapeHandle grad = ShapeOrHandleShape(c, grad_idx);
  if (!sparse) {
    TF_RETURN_IF_ERROR(c->Merge(*s, grad, s));
    return Status::OK();
  }
  // Indices is a vector whose dim 0 must match dim 0 of grad.
  ShapeHandle indices;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(grad_idx + 1), 1, &indices));
  DimensionHandle unused;
  TF_RETURN_IF_ERROR(c->Merge(c->Dim(indices, 0), c->Dim(grad, 0), &unused));

  // Trailing part of grad matches trailing part of *s.
  ShapeHandle grad_unknown_first;
  TF_RETURN_IF_ERROR(
      c->ReplaceDim(grad, 0, c->UnknownDim(), &grad_unknown_first));
  TF_RETURN_IF_ERROR(c->Merge(*s, grad_unknown_first, s));

  return Status::OK();
}

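// Shape function for (Resource)ApplyGradientDescent: alpha must be a scalar,
// var and delta must have compatible shapes, and the merged shape is
// forwarded to the output when the op has one (the resource variant does
// not).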
static Status ApplyGradientDescentShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);                  // var
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));  // alpha
  TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s));          // delta
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}

REGISTER_OP("ApplyGradientDescent")
    .Input("var: Ref(T)")
    .Input("alpha: T")
    .Input("delta: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyGradientDescentShapeFn);

REGISTER_OP("ResourceApplyGradientDescent")
    .Input("var: resource")
    .Input("alpha: T")
    .Input("delta: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ApplyGradientDescentShapeFn);

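// Shape function for the proximal gradient-descent variants: alpha, l1, and
// l2 must be scalars; delta (or grad plus indices, when <sparse>) is checked
// against var's shape.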
static Status ApplyProximalGradientDescentShapeFn(InferenceContext* c,
                                                  bool sparse) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);                  // var
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));  // alpha
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));  // l1
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // l2
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs(c, sparse, 4 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}

REGISTER_OP("ApplyProximalGradientDescent")
    .Input("var: Ref(T)")
    .Input("alpha: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("delta: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyProximalGradientDescentShapeFn(c, false /* sparse */);
    });

REGISTER_OP("SparseApplyProximalGradientDescent")
    .Input("var: Ref(T)")
    .Input("alpha: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyProximalGradientDescentShapeFn(c, true /* sparse */);
    });

REGISTER_OP("ResourceApplyProximalGradientDescent")
    .Input("var: resource")
    .Input("alpha: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("delta: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyProximalGradientDescentShapeFn(c, false /* sparse */);
    });

REGISTER_OP("ResourceSparseApplyProximalGradientDescent")
    .Input("var: resource")
    .Input("alpha: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyProximalGradientDescentShapeFn(c, true /* sparse */);
    });

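// Shape function for the Adadelta variants: var, accum, and accum_update must
// have compatible shapes; lr, rho, and epsilon must be scalars.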
static Status ApplyAdadeltaShapeFn(InferenceContext* c, bool sparse) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // accum
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape(c, 2), &s));            // accum_update
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));  // rho
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));  // epsilon
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs(c, sparse, 6 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}

REGISTER_OP("ApplyAdadelta")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("accum_update: Ref(T)")
    .Input("lr: T")
    .Input("rho: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdadeltaShapeFn(c, false /* sparse */);
    });

REGISTER_OP("SparseApplyAdadelta")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("accum_update: Ref(T)")
    .Input("lr: T")
    .Input("rho: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdadeltaShapeFn(c, true /* sparse */);
    });

REGISTER_OP("ResourceApplyAdadelta")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("accum_update: resource")
    .Input("lr: T")
    .Input("rho: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdadeltaShapeFn(c, false /* sparse */);
    });

REGISTER_OP("ResourceSparseApplyAdadelta")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("accum_update: resource")
    .Input("lr: T")
    .Input("rho: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdadeltaShapeFn(c, true /* sparse */);
    });

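// Shape function for the Adagrad variants: var and accum must have compatible
// shapes and lr must be a scalar.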
static Status ApplyAdagradShapeFn(InferenceContext* c, bool sparse) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // accum
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));       // lr
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}

REGISTER_OP("ApplyAdagrad")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradShapeFn(c, false /* sparse */);
    });

REGISTER_OP("ResourceApplyAdagrad")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradShapeFn(c, false /* sparse */);
    });

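// Shape function for the proximal Adagrad variants: var and accum must have
// compatible shapes; lr, l1, and l2 must be scalars.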
static Status ApplyProximalAdagradShapeFn(InferenceContext* c, bool sparse) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // accum
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));       // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));       // l1
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));       // l2
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs(c, sparse, 5 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}

REGISTER_OP("ApplyProximalAdagrad")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyProximalAdagradShapeFn(c, false /* sparse */);
    });

REGISTER_OP("ResourceApplyProximalAdagrad")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyProximalAdagradShapeFn(c, false /* sparse */);
    });

REGISTER_OP("SparseApplyAdagrad")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradShapeFn(c, true /* sparse */);
    });

REGISTER_OP("ResourceSparseApplyAdagrad")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("update_slots: bool = true")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradShapeFn(c, true /* sparse */);
    });

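// Shape function for the AdagradDA variants: var and both accumulators must
// have compatible shapes. The scalar inputs (lr, l1, l2, global_step) follow
// grad, and indices when <sparse>, hence the index offset below.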
static Status ApplyAdagradDAShapeFn(InferenceContext* c, bool sparse) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);  // var
  TF_RETURN_IF_ERROR(
      c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // grad_accumulator
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2),
                              &s));  // gradient_squared_accumulator
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
  int idx = sparse ? 5 : 4;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // l1
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // l2
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // global_step
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}

REGISTER_OP("ApplyAdagradDA")
    .Input("var: Ref(T)")
    .Input("gradient_accumulator: Ref(T)")
    .Input("gradient_squared_accumulator: Ref(T)")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("global_step: int64")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradDAShapeFn(c, false /* sparse */);
    });

REGISTER_OP("SparseApplyAdagradDA")
    .Input("var: Ref(T)")
    .Input("gradient_accumulator: Ref(T)")
    .Input("gradient_squared_accumulator: Ref(T)")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("global_step: int64")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradDAShapeFn(c, true /* sparse */);
    });

REGISTER_OP("SparseApplyProximalAdagrad")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyProximalAdagradShapeFn(c, true /* sparse */);
    });

REGISTER_OP("ResourceApplyAdagradDA")
    .Input("var: resource")
    .Input("gradient_accumulator: resource")
    .Input("gradient_squared_accumulator: resource")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("global_step: int64")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradDAShapeFn(c, false /* sparse */);
    });

REGISTER_OP("ResourceSparseApplyAdagradDA")
    .Input("var: resource")
    .Input("gradient_accumulator: resource")
    .Input("gradient_squared_accumulator: resource")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("global_step: int64")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdagradDAShapeFn(c, true /* sparse */);
    });

REGISTER_OP("ResourceSparseApplyProximalAdagrad")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyProximalAdagradShapeFn(c, true /* sparse */);
    });

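// Shape function for the FTRL variants: var, accum, and linear must have
// compatible shapes, followed by four consecutive scalar inputs. Note that
// for the FtrlV2 ops those four scalars land on lr, l1, l2, and l2_shrinkage,
// leaving the trailing lr_power input unchecked.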
static Status ApplyFtrlShapeFn(InferenceContext* c, bool sparse) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // accum
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s));  // linear
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
  int idx = sparse ? 5 : 4;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // l1
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // l2
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // lr_power
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}

REGISTER_OP("ApplyFtrl")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("linear: Ref(T)")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("lr_power: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyFtrlShapeFn(c, false /* sparse */);
    });

REGISTER_OP("SparseApplyFtrl")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("linear: Ref(T)")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("lr_power: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyFtrlShapeFn(c, true /* sparse */);
    });

REGISTER_OP("ResourceApplyFtrl")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("linear: resource")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("lr_power: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyFtrlShapeFn(c, false /* sparse */);
    });

REGISTER_OP("ResourceSparseApplyFtrl")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("linear: resource")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("lr_power: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyFtrlShapeFn(c, true /* sparse */);
    });

REGISTER_OP("ApplyFtrlV2")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("linear: Ref(T)")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("l2_shrinkage: T")
    .Input("lr_power: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyFtrlShapeFn(c, false /* sparse */);
    });

REGISTER_OP("SparseApplyFtrlV2")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("linear: Ref(T)")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("l2_shrinkage: T")
    .Input("lr_power: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyFtrlShapeFn(c, true /* sparse */);
    });

REGISTER_OP("ResourceApplyFtrlV2")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("linear: resource")
    .Input("grad: T")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("l2_shrinkage: T")
    .Input("lr_power: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyFtrlShapeFn(c, false /* sparse */);
    });

REGISTER_OP("ResourceSparseApplyFtrlV2")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("linear: resource")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("lr: T")
    .Input("l1: T")
    .Input("l2: T")
    .Input("l2_shrinkage: T")
    .Input("lr_power: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyFtrlShapeFn(c, true /* sparse */);
    });

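// Shape function for the momentum variants (including the Keras flavor): var
// and accum must have compatible shapes; lr and momentum must be scalars,
// with momentum positioned after grad (and indices when <sparse>).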
static Status ApplyMomentumShapeFn(InferenceContext* c, bool sparse) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // accum
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));       // lr
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs(c, sparse, 3 /* grad_idx */, &s));
  int idx = sparse ? 5 : 4;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(idx++), 0, &unused));  // momentum
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}

REGISTER_OP("ApplyMomentum")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("grad: T")
    .Input("momentum: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyMomentumShapeFn(c, false /* sparse */);
    });

REGISTER_OP("SparseApplyMomentum")
    .Input("var: Ref(T)")
    .Input("accum: Ref(T)")
    .Input("lr: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("momentum: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyMomentumShapeFn(c, true /* sparse */);
    });

REGISTER_OP("ResourceApplyMomentum")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Input("momentum: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyMomentumShapeFn(c, false /* sparse */);
    });

REGISTER_OP("ResourceSparseApplyMomentum")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("momentum: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyMomentumShapeFn(c, true /* sparse */);
    });

REGISTER_OP("ResourceApplyKerasMomentum")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Input("momentum: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyMomentumShapeFn(c, false /* sparse */);
    });

REGISTER_OP("ResourceSparseApplyKerasMomentum")
    .Input("var: resource")
    .Input("accum: resource")
    .Input("lr: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Input("momentum: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyMomentumShapeFn(c, true /* sparse */);
    });

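// Shape function for the Adam variants: var, m, and v must have compatible
// shapes; beta1_power, beta2_power, lr, beta1, beta2, and epsilon must be
// scalars.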
static Status ApplyAdamShapeFn(InferenceContext* c, bool sparse) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // m
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s));  // v
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // beta1_power
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));  // beta2_power
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));  // beta1
  TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));  // beta2
  TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));  // epsilon
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs(c, sparse, 9 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}

REGISTER_OP("ApplyAdam")
    .Input("var: Ref(T)")
    .Input("m: Ref(T)")
    .Input("v: Ref(T)")
    .Input("beta1_power: T")
    .Input("beta2_power: T")
    .Input("lr: T")
    .Input("beta1: T")
    .Input("beta2: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdamShapeFn(c, false /* sparse */);
    });

REGISTER_OP("ResourceApplyAdam")
    .Input("var: resource")
    .Input("m: resource")
    .Input("v: resource")
    .Input("beta1_power: T")
    .Input("beta2_power: T")
    .Input("lr: T")
    .Input("beta1: T")
    .Input("beta2: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .Attr("use_nesterov: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdamShapeFn(c, false /* sparse */);
    });

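// Shape function for Adam with AMSGrad: identical to the Adam check except
// for the extra vhat slot, which must also match var's shape.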
static Status ApplyAdamWithAmsgradShapeFn(InferenceContext* c, bool sparse) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // m
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s));  // v
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 3), &s));  // vhat
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));  // beta1_power
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));  // beta2_power
  TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));  // beta1
  TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));  // beta2
  TF_RETURN_IF_ERROR(c->WithRank(c->input(9), 0, &unused));  // epsilon
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs(c, sparse, 10 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}

REGISTER_OP("ResourceApplyAdamWithAmsgrad")
    .Input("var: resource")
    .Input("m: resource")
    .Input("v: resource")
    .Input("vhat: resource")
    .Input("beta1_power: T")
    .Input("beta2_power: T")
    .Input("lr: T")
    .Input("beta1: T")
    .Input("beta2: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdamWithAmsgradShapeFn(c, false /* sparse */);
    });

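// Shape function for the AdaMax variants: var, m, and v must have compatible
// shapes; beta1_power, lr, beta1, beta2, and epsilon must be scalars.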
static Status ApplyAdaMaxShapeFn(InferenceContext* c, bool sparse) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // m
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s));  // v
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));  // beta1_power
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));  // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));  // beta1
  TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));  // beta2
  TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));  // epsilon
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs(c, sparse, 8 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}

REGISTER_OP("ApplyAdaMax")
    .Input("var: Ref(T)")
    .Input("m: Ref(T)")
    .Input("v: Ref(T)")
    .Input("beta1_power: T")
    .Input("lr: T")
    .Input("beta1: T")
    .Input("beta2: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdaMaxShapeFn(c, false /* sparse */);
    });

REGISTER_OP("ResourceApplyAdaMax")
    .Input("var: resource")
    .Input("m: resource")
    .Input("v: resource")
    .Input("beta1_power: T")
    .Input("lr: T")
    .Input("beta1: T")
    .Input("beta2: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAdaMaxShapeFn(c, false /* sparse */);
    });

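// Shape function for the RMSProp variants: var, ms, and mom must have
// compatible shapes; lr, rho, momentum, and epsilon must be scalars.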
static Status ApplyRMSPropShapeFn(InferenceContext* c, bool sparse) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // ms
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s));  // mom
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));       // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));       // rho
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));       // momentum
  TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));       // epsilon
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs(c, sparse, 7 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}

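// Shape function for the centered RMSProp variants: like plain RMSProp, but
// with the additional mean-gradient slot mg, which must also match var's
// shape.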
static Status ApplyCenteredRMSPropShapeFn(InferenceContext* c, bool sparse) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // mg
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 2), &s));  // ms
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 3), &s));  // mom
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));       // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));       // rho
  TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));       // momentum
  TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));       // epsilon
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs(c, sparse, 8 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}

REGISTER_OP("ApplyRMSProp")
    .Input("var: Ref(T)")
    .Input("ms: Ref(T)")
    .Input("mom: Ref(T)")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyRMSPropShapeFn(c, false /* sparse */);
    });

REGISTER_OP("ApplyCenteredRMSProp")
    .Input("var: Ref(T)")
    .Input("mg: Ref(T)")
    .Input("ms: Ref(T)")
    .Input("mom: Ref(T)")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyCenteredRMSPropShapeFn(c, false /* sparse */);
    });

REGISTER_OP("SparseApplyRMSProp")
    .Input("var: Ref(T)")
    .Input("ms: Ref(T)")
    .Input("mom: Ref(T)")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyRMSPropShapeFn(c, true /* sparse */);
    });

REGISTER_OP("SparseApplyCenteredRMSProp")
    .Input("var: Ref(T)")
    .Input("mg: Ref(T)")
    .Input("ms: Ref(T)")
    .Input("mom: Ref(T)")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyCenteredRMSPropShapeFn(c, true /* sparse */);
    });

REGISTER_OP("ResourceApplyRMSProp")
    .Input("var: resource")
    .Input("ms: resource")
    .Input("mom: resource")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyRMSPropShapeFn(c, false /* sparse */);
    });

REGISTER_OP("ResourceApplyCenteredRMSProp")
    .Input("var: resource")
    .Input("mg: resource")
    .Input("ms: resource")
    .Input("mom: resource")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyCenteredRMSPropShapeFn(c, false /* sparse */);
    });

REGISTER_OP("ResourceSparseApplyRMSProp")
    .Input("var: resource")
    .Input("ms: resource")
    .Input("mom: resource")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyRMSPropShapeFn(c, true /* sparse */);
    });

REGISTER_OP("ResourceSparseApplyCenteredRMSProp")
    .Input("var: resource")
    .Input("mg: resource")
    .Input("ms: resource")
    .Input("mom: resource")
    .Input("lr: T")
    .Input("rho: T")
    .Input("momentum: T")
    .Input("epsilon: T")
    .Input("grad: T")
    .Input("indices: Tindices")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyCenteredRMSPropShapeFn(c, true /* sparse */);
    });

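// Shape function for the AddSign variants: var and m must have compatible
// shapes; lr, alpha, sign_decay, and beta must be scalars.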
static Status ApplyAddSignShapeFn(InferenceContext* c, bool sparse) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // m
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));       // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));       // alpha
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));       // sign_decay
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));       // beta
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs(c, sparse, 6 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}

REGISTER_OP("ApplyAddSign")
    .Input("var: Ref(T)")
    .Input("m: Ref(T)")
    .Input("lr: T")
    .Input("alpha: T")
    .Input("sign_decay: T")
    .Input("beta: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAddSignShapeFn(c, /*sparse=*/false);
    });

REGISTER_OP("ResourceApplyAddSign")
    .Input("var: resource")
    .Input("m: resource")
    .Input("lr: T")
    .Input("alpha: T")
    .Input("sign_decay: T")
    .Input("beta: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyAddSignShapeFn(c, /*sparse=*/false);
    });

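// Shape function for the PowerSign variants: var and m must have compatible
// shapes; lr, logbase, sign_decay, and beta must be scalars.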
static Status ApplyPowerSignShapeFn(InferenceContext* c, bool sparse) {
  ShapeHandle unused;
  ShapeHandle s = ShapeOrHandleShape(c, 0);                       // var
  TF_RETURN_IF_ERROR(c->Merge(s, ShapeOrHandleShape(c, 1), &s));  // m
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));       // lr
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));       // logbase
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));       // sign_decay
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));       // beta
  TF_RETURN_IF_ERROR(
      HandleGradAndIndicesInputs(c, sparse, 6 /* grad_idx */, &s));
  if (c->num_outputs() > 0) {
    c->set_output(0, s);
  }
  return Status::OK();
}

REGISTER_OP("ApplyPowerSign")
    .Input("var: Ref(T)")
    .Input("m: Ref(T)")
    .Input("lr: T")
    .Input("logbase: T")
    .Input("sign_decay: T")
    .Input("beta: T")
    .Input("grad: T")
    .Output("out: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyPowerSignShapeFn(c, /*sparse=*/false);
    });

REGISTER_OP("ResourceApplyPowerSign")
    .Input("var: resource")
    .Input("m: resource")
    .Input("lr: T")
    .Input("logbase: T")
    .Input("sign_decay: T")
    .Input("beta: T")
    .Input("grad: T")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      return ApplyPowerSignShapeFn(c, /*sparse=*/false);
    });

}  // namespace tensorflow