/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

REGISTER_OP("AddN")
    .Input("inputs: N * T")
    .Output("sum: T")
    .Attr("N: int >= 1")
    .Attr("T: {numbertype, variant}")
    .SetIsCommutative()
    .SetIsAggregate()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle cur = c->input(c->num_inputs() - 1);
      for (int i = c->num_inputs() - 2; i >= 0; --i) {
        TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
                                        "From merging shape ", i,
                                        " with other shapes.");
      }
      c->set_output(0, cur);

      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));

      if (dtype != DT_VARIANT) {
        // Exit early if not DT_VARIANT.
        return Status::OK();
      } else {
        // DT_VARIANT shape handle shape inference. All sizes and dtypes must
        // be the same; all shapes must be compatible via Merge.
        std::vector<shape_inference::ShapeAndType> cur_shapes_and_types;
        auto* shapes_and_types =
            c->input_handle_shapes_and_types(c->num_inputs() - 1);
        if (shapes_and_types) {
          cur_shapes_and_types = *shapes_and_types;
        }

        for (int i = c->num_inputs() - 2; i >= 0; --i) {
          auto shapes_and_types_i = c->input_handle_shapes_and_types(i);
          if (!shapes_and_types && shapes_and_types_i) {
            // TODO(ebrevdo): Find cases where this happens and fix their shape
            // inference. If we are calling AddN on variant types, they should
            // all have consistent shape_and_type info.
            shapes_and_types = shapes_and_types_i;
          } else if (shapes_and_types && shapes_and_types_i) {
            if (shapes_and_types_i->size() != shapes_and_types->size()) {
              return errors::InvalidArgument(
                  "shapes_and_types[", i,
                  "].size() == ", shapes_and_types_i->size(),
                  " != shapes_and_types[0].size() == ",
                  shapes_and_types->size());
            }
            for (int j = 0; j < shapes_and_types->size(); ++j) {
              if (shapes_and_types->at(j).dtype !=
                  shapes_and_types_i->at(j).dtype) {
                return errors::InvalidArgument(
                    "shapes_and_types[", i, "][", j, "].dtype() == ",
                    DataTypeString(shapes_and_types_i->at(j).dtype),
                    " != shapes_and_types[0][", j, "].dtype == ",
                    DataTypeString(shapes_and_types->at(j).dtype));
              }
              TF_RETURN_WITH_CONTEXT_IF_ERROR(
                  c->Merge(shapes_and_types_i->at(j).shape,
                           cur_shapes_and_types.at(j).shape,
                           &cur_shapes_and_types.at(j).shape),
                  "From merging shapes_and_types[", i, "][", j, "].shape with ",
                  "shapes_and_types[0][", j, "].shape");
            }
          }
        }
        if (shapes_and_types) {
          c->set_output_handle_shapes_and_types(0, cur_shapes_and_types);
        }
        return Status::OK();
      }
    });
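
// For example (illustrative): Merge unifies partially known shapes, so AddN
// inputs with shapes [?, 3] and [2, ?] infer the output shape [2, 3], while
// incompatible inputs such as [2] and [3] fail shape inference here.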

// --------------------------------------------------------------------------

// Note that the following operator is just a placeholder and has no
// associated kernel. The code in accumulate_n_optimizer.cc replaces
// this placeholder with a graph of operators that do have kernels.
// The Python code that generates instances of this op is currently in
// contrib/framework/python/ops/accumulate_n_v2.py
REGISTER_OP("AccumulateNV2")
    .Input("inputs: N * T")
    .Output("sum: T")
    .Attr("N: int >= 1")
    .Attr("T: numbertype")
    .Attr("shape: shape")
    .SetIsCommutative()
    .SetIsAggregate()
    .SetShapeFn(shape_inference::ExplicitShape);

// --------------------------------------------------------------------------

REGISTER_OP("BatchMatMul")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr(
        "T: {bfloat16, half, float, double, int32, int64, complex64, "
        "complex128}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulShape);

REGISTER_OP("BatchMatMulV2")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr(
        "T: {bfloat16, half, float, double, int16, int32, int64, complex64, "
        "complex128}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulV2Shape);

REGISTER_OP("BatchMatMulV3")
    .Input("x: Ta")
    .Input("y: Tb")
    .Output("output: Tout")
    .Attr(
        "Ta: {bfloat16, half, float, double, uint8, int8, int16, int32, "
        "int64, complex64, complex128}")
    .Attr(
        "Tb: {bfloat16, half, float, double, uint8, int8, int16, int32, "
        "int64, complex64, complex128}")
    .Attr(
        "Tout: {bfloat16, half, float, double, int16, int32, int64, "
        "complex64, complex128}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulV2Shape);

#ifdef INTEL_MKL
REGISTER_OP("_MklBatchMatMul")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulShape);

REGISTER_OP("_MklBatchMatMulV2")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulV2Shape);
#endif  // INTEL_MKL

// --------------------------------------------------------------------------
// Casting Ops
//
// NOTE: Only a subset of types is supported by Cast. The exact casting rule
// is TBD. The current implementation uses C++ static_cast rules for numeric
// types, which may change in the future.
REGISTER_OP("Cast")
    .Input("x: SrcT")
    .Output("y: DstT")
    .Attr("SrcT: type")
    .Attr("DstT: type")
    .Attr("Truncate: bool = false")
    .SetShapeFn(shape_inference::UnchangedShape);
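
// For example (illustrative): with SrcT=float and DstT=int32, Cast follows
// C++ static_cast semantics, so 2.7f becomes 2 (truncation toward zero).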

REGISTER_OP("_HostCast")
    .Input("x: SrcT")
    .Output("y: DstT")
    .Attr("SrcT: type")
    .Attr("DstT: type")
    .Attr("Truncate: bool = false")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Cast x of type SrcT to y of DstT.

_HostCast requires its input and produces its output in host memory.
)doc");

// --------------------------------------------------------------------------

REGISTER_OP("Abs")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double, int8, int16, int32, int64}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("ComplexAbs")
    .Input("x: T")
    .Output("y: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

// Declares cwise unary operations signature: 't -> 't
#define UNARY()                                                            \
  Input("x: T")                                                            \
      .Output("y: T")                                                      \
      .Attr(                                                               \
          "T: {bfloat16, half, float, double, int8, int16, int32, int64, " \
          "complex64, complex128}")                                        \
      .SetShapeFn(shape_inference::UnchangedShape)

#define UNARY_REAL()                              \
  Input("x: T")                                   \
      .Output("y: T")                             \
      .Attr("T: {bfloat16, half, float, double}") \
      .SetShapeFn(shape_inference::UnchangedShape)

#define UNARY_COMPLEX()                                                  \
  Input("x: T")                                                          \
      .Output("y: T")                                                    \
      .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
      .SetShapeFn(shape_inference::UnchangedShape)

#define UNARY_GRADIENT_COMPLEX()                                         \
  Input("y: T")                                                          \
      .Input("dy: T")                                                    \
      .Output("z: T")                                                    \
      .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
      .SetShapeFn(shape_inference::UnchangedShape)
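
// For example (illustrative), REGISTER_OP("Neg").UNARY(); expands to:
//
//   REGISTER_OP("Neg")
//       .Input("x: T")
//       .Output("y: T")
//       .Attr(
//           "T: {bfloat16, half, float, double, int8, int16, int32, int64, "
//           "complex64, complex128}")
//       .SetShapeFn(shape_inference::UnchangedShape);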

REGISTER_OP("Neg").UNARY();

REGISTER_OP("Inv").UNARY();

REGISTER_OP("InvGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Reciprocal").UNARY();

REGISTER_OP("ReciprocalGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Square").UNARY();

REGISTER_OP("Sqrt").UNARY_COMPLEX();

REGISTER_OP("SqrtGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Rsqrt").UNARY_COMPLEX();

REGISTER_OP("Round").UNARY();

REGISTER_OP("RsqrtGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Exp").UNARY_COMPLEX();

REGISTER_OP("Expm1").UNARY_COMPLEX();

REGISTER_OP("Log").UNARY_COMPLEX();

REGISTER_OP("Log1p").UNARY_COMPLEX();

REGISTER_OP("Sinh").UNARY_COMPLEX();

REGISTER_OP("Cosh").UNARY_COMPLEX();

REGISTER_OP("Tanh").UNARY_COMPLEX();

REGISTER_OP("Asinh").UNARY_COMPLEX();

REGISTER_OP("Acosh").UNARY_COMPLEX();

REGISTER_OP("Atanh").UNARY_COMPLEX();

REGISTER_OP("TanhGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Lgamma").UNARY_REAL();

REGISTER_OP("Digamma").UNARY_REAL();

REGISTER_OP("Erf").UNARY_REAL();
REGISTER_OP("Erfinv").UNARY_REAL();
REGISTER_OP("Ndtri").UNARY_REAL();
REGISTER_OP("Erfc").UNARY_REAL();

REGISTER_OP("Sigmoid").UNARY_COMPLEX();

REGISTER_OP("SigmoidGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Sin").UNARY_COMPLEX();

REGISTER_OP("Cos").UNARY_COMPLEX();

REGISTER_OP("Tan").UNARY();

REGISTER_OP("Asin").UNARY();

REGISTER_OP("Acos").UNARY();

REGISTER_OP("Atan").UNARY();

REGISTER_OP("_UnaryOpsComposition")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {float, half, double}")
    .Attr("op_names: list(string)")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. A graph rewrite pass
is expected to create these operators.
)doc");

#undef UNARY
#undef UNARY_REAL
#undef UNARY_COMPLEX

REGISTER_OP("IsNan")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("IsInf")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("IsFinite")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Sign")
    .Input("x: T")
    .Output("y: T")
    .Attr(
        "T: {bfloat16, half, float, double, int8, int16, int32, int64, "
        "complex64, complex128}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Floor")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Ceil")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Rint")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

// Declares cwise binary operations signature: 't, 't -> 't.

#define BINARY_MORE()                                                       \
  Input("x: T").Input("y: T").Output("z: T").Attr(                          \
      "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, "     \
      "int32, uint32, uint64, int64, complex64, complex128}")

#define BINARY_FEWER()                                               \
  Input("x: T").Input("y: T").Output("z: T").Attr(                   \
      "T: {bfloat16, half, float, double, int32, int64, complex64, " \
      "complex128}")

REGISTER_OP("Add")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, uint8, int8, int16, int32, "
        "int64, complex64, complex128, string}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("AddV2")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, uint8, uint16, uint32, uint64, "
        "int8, int16, int32, int64, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .SetIsAggregate()
    .SetIsCommutative();

#ifdef INTEL_MKL
REGISTER_OP("_MklAdd")
    .Input("x: T")
    .Input("y: T")
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("z: T")
    .Output("mkl_z: uint8")
    .Attr(
        "T: {half, float, double, uint8, int8, int16, int32, int64, "
        "complex64, complex128, string, bfloat16}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns `x` + `y` element-wise.

*NOTE*: `tf.math.add` supports broadcasting. `tf.math.add_n` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
)doc");

REGISTER_OP("_MklAddV2")
    .Input("x: T")
    .Input("y: T")
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("z: T")
    .Output("mkl_z: uint8")
    .Attr(
        "T: {bfloat16, half, float, double, uint8, int8, int16, int32, "
        "int64, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .SetIsAggregate()
    .SetIsCommutative()
    .Doc(R"doc(
Returns `x` + `y` element-wise.
*NOTE*: `tf.math.add` supports broadcasting. `tf.math.add_n` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
)doc");
#endif  // INTEL_MKL

REGISTER_OP("Sub")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, "
        "int32, int64, complex64, complex128, uint32, uint64}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("_MklSub")
    .BINARY_FEWER()
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("mkl_z: uint8")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns x - y element-wise.

*NOTE*: `Sub` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

REGISTER_OP("Mul").BINARY_MORE().SetIsCommutative().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("MulNoNan")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Note: This op is not commutative w.r.t. all its inputs.
REGISTER_OP("_MklMul")
    .BINARY_MORE()
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("mkl_z: uint8")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns x * y element-wise.

*NOTE*: `Mul` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("DivNoNan")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, bfloat16, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("FloorDiv")
    .BINARY_MORE()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("TruncateDiv")
    .BINARY_MORE()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("RealDiv").BINARY_MORE().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

// Note SquaredDifference implements conj(x - y)*(x - y).
REGISTER_OP("SquaredDifference")
    .BINARY_FEWER()
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Note: This op is not commutative w.r.t. all its inputs.
REGISTER_OP("_MklSquaredDifference")
    .BINARY_FEWER()
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("mkl_z: uint8")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns (x - y)(x - y) element-wise.

*NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

REGISTER_OP("Xlogy")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Xlog1py")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Xdivy")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

#undef BINARY_FEWER
#undef BINARY_MORE

REGISTER_OP("Maximum")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, "
        "int32, uint32, int64, uint64}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Note: This op is not commutative w.r.t. all its inputs.
REGISTER_OP("_MklMaximum")
    .Input("x: T")
    .Input("y: T")
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("z: T")
    .Output("mkl_z: uint8")
    .Attr("T: {half, float, double, int32, int64, bfloat16}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns the max of x and y (i.e. x > y ? x : y) element-wise.

*NOTE*: `Maximum` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

REGISTER_OP("Minimum")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, "
        "int32, uint32, int64, uint64}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Mod")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {int32, int64, float16, half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("FloorMod")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64, "
        "bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("TruncateMod")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {int32, int64, bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Pow")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, float, half, double, int8, int16, int32, int64, "
        "complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Igammac")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Igamma")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("IgammaGradA")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Zeta")
    .Input("x: T")
    .Input("q: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Polygamma")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Atan2")
    .Input("y: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Betainc")
    .Input("a: T")
    .Input("b: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn([](InferenceContext* c) {
      const int num_inputs = 3;
      ShapeHandle output = c->UnknownShape();
      int num_scalars = 0;
      ShapeHandle some_non_scalar;
      for (int i = 0; i < num_inputs; ++i) {
        ShapeHandle in = c->input(i);
        if (!c->RankKnown(in)) {
          some_non_scalar = in;
          // An input with unknown rank could be either a scalar (to be
          // broadcast) or some other shape.
        } else if (c->Rank(in) == 0) {
          // Input is a scalar, it will be broadcast to the output shape.
          ++num_scalars;
        } else {
          TF_RETURN_IF_ERROR(c->Merge(output, in, &output));
          some_non_scalar = output;
        }
      }

      if (num_scalars == num_inputs - 1) {
        // If all but one input is known to be a scalar, then output is the
        // remaining input.
        output = some_non_scalar;
      } else if (num_scalars == num_inputs) {
        // If all are scalars, output is scalar; pick the first one arbitrarily.
        output = c->input(0);
      }

      c->set_output(0, output);
      return Status::OK();
    });
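
// For example (illustrative): Betainc with input shapes ([], [2, 3], [])
// counts two scalars out of three inputs, so the output takes the one
// non-scalar shape, [2, 3]; if all three inputs are scalars, so is the output.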

// --------------------------------------------------------------------------

// Declares cwise binary comparison operations signature: 't, 't -> bool,
// where 't has a natural total order.
#define COMPARISON()                                        \
  Input("x: T")                                             \
      .Input("y: T")                                        \
      .Output("z: bool")                                    \
      .Attr("T: realnumbertype")                            \
      .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)

REGISTER_OP("Less").COMPARISON();

REGISTER_OP("LessEqual").COMPARISON();

REGISTER_OP("Greater").COMPARISON();

REGISTER_OP("GreaterEqual").COMPARISON();

#undef COMPARISON

// --------------------------------------------------------------------------

#define EQUALITY_COMPARISON()                                      \
  Input("x: T")                                                    \
      .Input("y: T")                                               \
      .Output("z: bool")                                           \
      .SetIsCommutative()                                          \
      .Attr("T: type")                                             \
      .Attr("incompatible_shape_error: bool = true")               \
      .SetShapeFn([](InferenceContext* c) {                        \
        ShapeHandle x = c->input(0);                               \
        ShapeHandle y = c->input(1);                               \
        ShapeHandle output;                                        \
        bool incompatible_shape_error;                             \
        TF_RETURN_IF_ERROR(c->GetAttr("incompatible_shape_error",  \
                                      &incompatible_shape_error)); \
        TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(   \
            c, x, y, incompatible_shape_error, &output));          \
        c->set_output(0, output);                                  \
        return Status::OK();                                       \
      })

REGISTER_OP("Equal").EQUALITY_COMPARISON();

REGISTER_OP("NotEqual").EQUALITY_COMPARISON();

#undef EQUALITY_COMPARISON

REGISTER_OP("ApproximateEqual")
    .Input("x: T")
    .Input("y: T")
    .Output("z: bool")
    .SetIsCommutative()
    .Attr("T: numbertype")
    .Attr("tolerance: float = 0.00001")
    .SetShapeFn([](InferenceContext* c) {
      // The inputs 'x' and 'y' must have the same shape.
      ShapeHandle data_x = c->input(0);
      ShapeHandle data_y = c->input(1);
      TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x));
      return shape_inference::UnchangedShape(c);
    });

// --------------------------------------------------------------------------

REGISTER_OP("LogicalNot")
    .Input("x: bool")
    .Output("y: bool")
    .SetShapeFn(shape_inference::UnchangedShape);

#define BINARY_LOGICAL()                                    \
  Input("x: bool")                                          \
      .Input("y: bool")                                     \
      .Output("z: bool")                                    \
      .SetIsCommutative()                                   \
      .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)

REGISTER_OP("LogicalAnd").BINARY_LOGICAL();

REGISTER_OP("LogicalOr").BINARY_LOGICAL();

#undef BINARY_LOGICAL

// --------------------------------------------------------------------------

REGISTER_OP("Select")
    .Input("condition: bool")
    .Input("t: T")
    .Input("e: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      auto* handle_data_1 = c->input_handle_shapes_and_types(1);
      auto* handle_data_2 = c->input_handle_shapes_and_types(2);
      // Merge handle shape and dtype if applicable.
      if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
        const auto size = handle_data_1->size();
        std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
        if (size != handle_data_2->size()) {
          return errors::InvalidArgument(
              "Trying to merge handles pointing to different numbers of "
              "tensors.");
        }

        for (int i = 0; i < size; ++i) {
          const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
          const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
          if (s1.dtype != s2.dtype) {
            // TODO(apassos) resolve this in the manner of b/32476923
            return errors::InvalidArgument(
                "Trying to merge handles pointing to different dtypes.");
          }
          merged_handle_data[i].dtype = s1.dtype;
          TF_RETURN_IF_ERROR(
              c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
        }

        c->set_output_handle_shapes_and_types(0, merged_handle_data);
      }

      // The inputs 'then' and 'else' must have the same shape.
      ShapeHandle data = c->input(1);
      ShapeHandle other = c->input(2);
      TF_RETURN_IF_ERROR(c->Merge(data, other, &data));

      // The input 'cond' must either have the same shape as 'then' and
      // 'else', or be a vector if 'then' and 'else' are at least vectors.
      ShapeHandle cond = c->input(0);

      if (!c->RankKnown(cond) || !c->RankKnown(data)) {
        c->set_output(0, data);
        return Status::OK();
      }

      // The ranks of 'cond' and 'data' are known.

      const int32_t cond_rank = c->Rank(cond);
      const int32_t data_rank = c->Rank(data);

      if (cond_rank == 0) {
        // 'cond' is a scalar, so 't' and 'e' can have any shape.
        c->set_output(0, data);
        return Status::OK();
      }

      if (cond_rank != 1) {
        // If 'cond' is neither a vector nor a scalar, its shape must match
        // the shape of 'then' and 'else'.
        TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
        c->set_output(0, data);
        return Status::OK();
      }

      if (data_rank == 0) {
        // If 'then' and 'else' are scalars, 'cond' must be a scalar as well.
        TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
        c->set_output(0, data);
        return Status::OK();
      }

      if (cond_rank == 1) {
        // If 'cond' is a vector and 'then' is not a scalar, 'cond' must match
        // the first dimension of 'then' and 'else'.
        TF_RETURN_IF_ERROR(c->Merge(cond, c->Vector(c->Dim(data, 0)), &cond));
        c->set_output(0, data);
        return Status::OK();
      }

      c->set_output(0, data);

      return Status::OK();
    });
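
// For example (illustrative): Select with 'condition' of shape [2] and 't',
// 'e' of shape [2, 3] yields output shape [2, 3]; a scalar 'condition'
// accepts 't' and 'e' of any shape, as long as they match each other.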

REGISTER_OP("SelectV2")
    .Input("condition: bool")
    .Input("t: T")
    .Input("e: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      auto* handle_data_1 = c->input_handle_shapes_and_types(1);
      auto* handle_data_2 = c->input_handle_shapes_and_types(2);
      // Merge handle shape and dtype if applicable.
      if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
        const auto size = handle_data_1->size();
        std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
        if (size != handle_data_2->size()) {
          return errors::InvalidArgument(
              "Trying to merge handles pointing to different numbers of "
              "tensors.");
        }

        for (int i = 0; i < size; ++i) {
          const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
          const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
          if (s1.dtype != s2.dtype) {
            // TODO(apassos) resolve this in the manner of b/32476923
            return errors::InvalidArgument(
                "Trying to merge handles pointing to different dtypes.");
          }
          merged_handle_data[i].dtype = s1.dtype;
          TF_RETURN_IF_ERROR(
              c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
        }

        c->set_output_handle_shapes_and_types(0, merged_handle_data);
      }

      // The inputs 'cond', 'then', and 'else' must be broadcastable.
      // TODO (yongtang): Consolidate 3-ary broadcast instead of
      // multiple 2-ary broadcast.
      ShapeHandle cond = c->input(0);
      ShapeHandle then = c->input(1);
      ShapeHandle else_ = c->input(2);
      ShapeHandle other;
      TF_RETURN_IF_ERROR(
          BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, true, &other));
      ShapeHandle output;
      TF_RETURN_IF_ERROR(
          BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, true, &output));
      c->set_output(0, output);
      return Status::OK();
    });
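
// For example (illustrative): unlike Select, SelectV2 broadcasts all three
// inputs, so 'condition' [2, 1], 't' [2, 3], and 'e' [1, 3] produce output
// shape [2, 3].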

// --------------------------------------------------------------------------

REGISTER_OP("MatMul")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr(
        "T: {bfloat16, half, float, double, int32, int64, complex64, "
        "complex128}")
    .SetShapeFn(shape_inference::MatMulShape);

#ifdef INTEL_MKL
REGISTER_OP("_MklMatMul")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("T: {bfloat16, float}")
    .SetShapeFn(shape_inference::MatMulShape);
#endif  // INTEL_MKL

REGISTER_OP("SparseMatMul")
    .Input("a: Ta")
    .Input("b: Tb")
    .Output("product: float")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("a_is_sparse: bool = false")
    .Attr("b_is_sparse: bool = false")
    .Attr("Ta: {float, bfloat16} = DT_FLOAT")
    .Attr("Tb: {float, bfloat16} = DT_FLOAT")
    .SetShapeFn(shape_inference::MatMulShape);

REGISTER_OP("_FusedMatMul")
    .Input("a: T")
    .Input("b: T")
    .Input("args: num_args * T")
    .Output("product: T")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("T: {bfloat16, float}")
    .Attr("num_args: int >= 0")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ----------- //
    .Attr("epsilon: float = 0.0001")
    // Attributes for the LeakyRelu ---------------- //
    .Attr("leakyrelu_alpha: float = 0.2")
    // --------------------------------------------- //
    .SetShapeFn(shape_inference::MatMulShape)
    .Doc(R"doc(
Performs a MatMul followed by a specified series of operations.

The inputs to the MatMul are specified by `a` and `b`. The series of operations
that follows is specified by the `fused_ops` attribute, which is a list of TF op
names specified as strings (e.g. "Relu"). They are performed in order, where the
(first) input to each op is the output of the preceding op. The first input and
the output of each fused_op must be of type T.

Currently supported fused_op combinations are: ["BiasAdd"] and ["BiasAdd", A],
where A is one of {"Elu", "Relu", "Relu6"}.

* The first input to BiasAdd is the MatMul result, and the additional BiasAdd
input is specified by `args`.
* If there is an op A specified, the output of the BiasAdd is the input to op A,
and op A produces the _FusedMatMul output. Otherwise, the BiasAdd produces the
_FusedMatMul output.

*NOTE*: Do not invoke this operator directly in Python. Grappler is
expected to create these operators.
)doc");
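
// For example (illustrative): fused_ops = ["BiasAdd", "Relu"] computes
// Relu(BiasAdd(MatMul(a, b), args[0])), with the bias vector passed through
// `args` and num_args = 1.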

// --------------------------------------------------------------------------

// For operations where the output is a reduction function along some
// dimensions of the input.
REGISTER_OP("Sum")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("EuclideanNorm")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Mean")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Prod")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Min")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: {realnumbertype, quantizedtype}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Max")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: {realnumbertype, quantizedtype}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

namespace {

Status ArgOpShape(shape_inference::InferenceContext* c) {
  ShapeHandle dimension_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &dimension_shape));

  ShapeHandle input_shape = c->input(0);
  if (!c->RankKnown(input_shape)) {
    return shape_inference::UnknownShape(c);
  }

  const int32_t input_rank = c->Rank(input_shape);
  if (input_rank <= 1) {
    // Reducing a scalar/vector must return a scalar.
    return shape_inference::ScalarShape(c);
  }

  const Tensor* dim_t = c->input_tensor(1);
  if (dim_t == nullptr) {
    // We don't know the value of the dimension, but we
    // know the rank of the input, so return the correct
    // rank with unknown dimensions.
    std::vector<DimensionHandle> dims(input_rank - 1);
    for (int i = 0; i < dims.size(); ++i) {
      dims[i] = c->UnknownDim();
    }

    c->set_output(0, c->MakeShape(dims));
    return Status::OK();
  }

  int64_t dimension_val;
  if (dim_t->dtype() == DT_INT32) {
    dimension_val = dim_t->scalar<int32>()();
  } else {
    dimension_val = dim_t->scalar<int64>()();
  }

  int64_t axis =
      dimension_val < 0 ? dimension_val + input_rank : dimension_val;
  if (axis < 0 || axis >= input_rank) {
    return errors::InvalidArgument(
        "Dimension (", dimension_val, ") must be in the range [", -input_rank,
        ", ", input_rank, "), where ", input_rank,
        " is the number of dimensions in the input.");
  }

  // Return the input shape without the dimension being reduced.
  std::vector<DimensionHandle> dims;
  for (int i = 0; i < input_rank; ++i) {
    if (axis != i) {
      dims.emplace_back(c->Dim(input_shape, i));
    }
  }
  c->set_output(0, c->MakeShape(dims));
  return Status::OK();
}
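
// For example (illustrative): for an input of shape [2, 3, 4] with
// dimension = 1, ArgOpShape drops the reduced axis and infers output shape
// [2, 4]; if the dimension value is not available, it returns [?, ?].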

}  // namespace

REGISTER_OP("ArgMax")
    .Input("input: T")
    .Input("dimension: Tidx")
    .Output("output: output_type")
    .Attr("T: {numbertype, bool}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("output_type: {int32, int64} = DT_INT64")
    .SetShapeFn(ArgOpShape);

REGISTER_OP("ArgMin")
    .Input("input: T")
    .Input("dimension: Tidx")
    .Output("output: output_type")
    .Attr("T: {numbertype, bool}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("output_type: {int32, int64} = DT_INT64")
    .SetShapeFn(ArgOpShape);

namespace {

Status SegmentReductionShapeFn(InferenceContext* c) {
  ShapeHandle data_shape;
  ShapeHandle segment_ids_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &segment_ids_shape));

  ShapeHandle subshape;
  TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));

  ShapeHandle out;
  TF_RETURN_IF_ERROR(
      c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
  c->set_output(0, out);
  return Status::OK();
}
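
// For example (illustrative): SegmentSum with data of shape [n, d] and
// segment_ids of shape [n] infers output shape [?, d]; the leading dimension
// is unknown because it depends on the values in segment_ids.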

Status SparseSegmentReductionShapeFn(InferenceContext* c) {
  ShapeHandle data_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));

  ShapeHandle indices_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));

  ShapeHandle segment_ids_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));

  // indices and segment_ids should merge cleanly.
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));

  ShapeHandle subshape;
  TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));

  ShapeHandle out;
  TF_RETURN_IF_ERROR(
      c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
  c->set_output(0, out);
  return Status::OK();
}

Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
  ShapeHandle data_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));

  ShapeHandle indices_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));

  // indices and segment_ids should merge cleanly.
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->Merge(c->input(2), indices_shape, &unused));

  // output_dim0 should be a scalar
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));

  ShapeHandle subshape;
  TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));

  const Tensor* dim0 = c->input_tensor(3);
  ShapeHandle dim0_shape;
  if (dim0 == nullptr) {
    // We don't have the value at inference time, so the output
    // shape is unknown.
    dim0_shape = c->Vector(InferenceContext::kUnknownDim);
  } else {
    auto dim0_value = dim0->scalar<int32>()();
    if (dim0_value < 0) {
      return errors::InvalidArgument(
          "Cannot specify a negative value for output_dim0");
    }
    dim0_shape = c->Vector(dim0_value);
  }

  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->Concatenate(dim0_shape, subshape, &out));
  c->set_output(0, out);
  return Status::OK();
}

Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) {
  ShapeHandle data_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));

  ShapeHandle indices_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));

  ShapeHandle segment_ids_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));

  ShapeHandle num_segments_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &num_segments_shape));

  // indices and segment_ids should merge cleanly.
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));

  ShapeHandle subshape;
  TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));

  ShapeHandle out;
  const Tensor* dim0 = c->input_tensor(3);
  if (dim0 == nullptr) {
    // We don't have the value at inference time, so the output
    // shape is unknown.
    TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(InferenceContext::kUnknownDim),
                                      subshape, &out));
  } else {
    auto dim0_value = dim0->scalar<int32>()();
    if (dim0_value < 0) {
      return errors::InvalidArgument(
          "Cannot specify a negative value for num_segments");
    }
    TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(dim0_value), subshape, &out));
  }
  c->set_output(0, out);
  return Status::OK();
}
}  // namespace

REGISTER_OP("SegmentSum")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentMean")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentProd")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentMin")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentMax")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("UnsortedSegmentSum")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);

REGISTER_OP("UnsortedSegmentMax")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);

REGISTER_OP("UnsortedSegmentMin")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);

REGISTER_OP("UnsortedSegmentProd")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentSum")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentSumWithNumSegments")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);

REGISTER_OP("SparseSegmentSumGrad")
    .Input("grad: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Input("output_dim0: int32")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionGradShapeFn);

REGISTER_OP("SparseSegmentMean")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentMeanWithNumSegments")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);

REGISTER_OP("SparseSegmentMeanGrad")
    .Input("grad: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Input("output_dim0: int32")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionGradShapeFn);

REGISTER_OP("SparseSegmentSqrtN")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);

REGISTER_OP("SparseSegmentSqrtNGrad")
    .Input("grad: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Input("output_dim0: int32")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionGradShapeFn);

REGISTER_OP("All")
    .Input("input: bool")
    .Input("reduction_indices: Tidx")
    .Output("output: bool")
    .Attr("keep_dims: bool = false")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Any")
    .Input("input: bool")
    .Input("reduction_indices: Tidx")
    .Attr("keep_dims: bool = false")
    .Output("output: bool")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

// --------------------------------------------------------------------------

namespace {

template <typename T>
Status RangeSize(const Tensor* start_t, const Tensor* limit_t,
                 const Tensor* delta_t, InferenceContext* const c) {
  T start = start_t->scalar<T>()();
  T limit = limit_t->scalar<T>()();
  T delta = delta_t->scalar<T>()();
  if (start > limit && delta > T(0)) {
    return errors::InvalidArgument(
        "Requires start <= limit when delta > 0: ", start, "/", limit);
  }
  if (start < limit && delta < T(0)) {
    return errors::InvalidArgument(
        "Requires start >= limit when delta < 0: ", start, "/", limit);
  }
  if (delta == T(0)) {
    return errors::InvalidArgument("Requires delta != 0");
  }

  auto size = (std::is_integral<T>::value
                   ? ((Eigen::numext::abs(limit - start) +
                       Eigen::numext::abs(delta) - T(1)) /
                      Eigen::numext::abs(delta))
                   : (Eigen::numext::ceil(
                         Eigen::numext::abs((limit - start) / delta))));
  c->set_output(0, c->Vector(static_cast<int64>(size)));
  return Status::OK();
}
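
// For example (illustrative): Range(start=0, limit=10, delta=3) is integral,
// so size = (|10 - 0| + |3| - 1) / |3| = 4, matching the values [0, 3, 6, 9];
// for floating-point types, size = ceil(|limit - start| / |delta|).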

}  // namespace

REGISTER_OP("Range")
    .Input("start: Tidx")
    .Input("limit: Tidx")
    .Input("delta: Tidx")
    .Output("output: Tidx")
    .Attr(
        "Tidx: "
        "{bfloat16, half, float, double, int8, int16, int32, int64, uint32} = "
        "DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
                                      " for 'start'");
      TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
                                      " for 'limit'");
      TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
                                      " for 'delta'");
      const Tensor* start_t = c->input_tensor(0);
      const Tensor* limit_t = c->input_tensor(1);
      const Tensor* delta_t = c->input_tensor(2);
      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
      if (start_t == nullptr || limit_t == nullptr || delta_t == nullptr) {
        c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
        return Status::OK();
      }
      if (dtype == DT_INT32) {
        return RangeSize<int32>(start_t, limit_t, delta_t, c);
      } else if (dtype == DT_INT16) {
        return RangeSize<int16>(start_t, limit_t, delta_t, c);
      } else if (dtype == DT_INT8) {
        return RangeSize<int8>(start_t, limit_t, delta_t, c);
      } else if (dtype == DT_INT64) {
        return RangeSize<int64>(start_t, limit_t, delta_t, c);
      } else if (dtype == DT_UINT32) {
        return RangeSize<uint32>(start_t, limit_t, delta_t, c);
      } else if (dtype == DT_FLOAT) {
        return RangeSize<float>(start_t, limit_t, delta_t, c);
      } else if (dtype == DT_DOUBLE) {
        return RangeSize<double>(start_t, limit_t, delta_t, c);
      } else if (dtype == DT_BFLOAT16) {
        return RangeSize<bfloat16>(start_t, limit_t, delta_t, c);
      } else {
        return errors::InvalidArgument("Unsupported dtype", dtype);
      }
      return Status::OK();
    });

REGISTER_OP("LinSpace")
    .Input("start: T")
    .Input("stop: T")
    .Input("num: Tidx")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
                                      " for 'start'");
      TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
                                      " for 'stop'");
      TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
                                      " for 'num'");
      const Tensor* num_t = c->input_tensor(2);
      if (num_t == nullptr) {
        c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
        return Status::OK();
      }

      int64_t num;
      if (num_t->dtype() == DT_INT32) {
        num = num_t->scalar<int32>()();
      } else {
        num = num_t->scalar<int64>()();
      }
      if (num <= 0) return errors::InvalidArgument("Requires num > 0: ", num);
      c->set_output(0, c->Vector(num));
      return Status::OK();
    });

REGISTER_OP("Complex")
    .Input("real: T")
    .Input("imag: T")
    .Output("out: Tout")
    .Attr("T: {float, double} = DT_FLOAT")
    .Attr("Tout: {complex64, complex128} = DT_COMPLEX64")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Real")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Imag")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Angle")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Conj")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {complex64, complex128, variant} = DT_COMPLEX64")
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->input(0));
      auto* handle_data = c->input_handle_shapes_and_types(0);
      if (handle_data != nullptr) {
        c->set_output_handle_shapes_and_types(0, *handle_data);
      }
      return Status::OK();
    });

// --------------------------------------------------------------------------

REGISTER_OP("Cross")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("T: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle a_shape;
      ShapeHandle b_shape;
      // * Input rank >= 1.
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &a_shape));
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &b_shape));

      // * Both inputs have the same shape.
      TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &a_shape));

      // * input_shape[-1] == 3.
      if (c->RankKnown(a_shape)) {
        int rank = c->Rank(a_shape);
        auto dim = c->Dim(a_shape, rank - 1);
        TF_RETURN_IF_ERROR(c->WithValue(dim, 3, &dim));
      }
      c->set_output(0, a_shape);
      return Status::OK();
    });

// --------------------------------------------------------------------------

REGISTER_OP("HistogramFixedWidth")
    .Input("values: T")
    .Input("value_range: T")
    .Input("nbins: int32")
    .Output("out: dtype")
    .Attr("T: {int32, int64, float32, float64}")
    .Attr("dtype: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      // value_range should be a vector.
      ShapeHandle value_range_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &value_range_shape));
      // value_range should have two elements.
      DimensionHandle unused;
      TF_RETURN_IF_ERROR(
          c->WithValue(c->Dim(value_range_shape, 0), 2, &unused));
      // nbins should be a scalar.
      ShapeHandle nbins_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &nbins_shape));

      // If nbins is available, set the shape from nbins.
      const Tensor* nbins_input = c->input_tensor(2);
      if (nbins_input != nullptr) {
        int64_t nbins;
        TF_RETURN_IF_ERROR(c->GetScalarFromTensor(nbins_input, &nbins));
        // nbins has to be positive.
        if (nbins <= 0) {
          return errors::InvalidArgument("Requires nbins > 0: ", nbins);
        }
        c->set_output(0, c->Vector(nbins));
      } else {
        c->set_output(0, c->UnknownShapeOfRank(1));
      }
      return Status::OK();
    });

REGISTER_OP("Bincount")
    .Input("arr: int32")
    .Input("size: int32")
    .Input("weights: T")
    .Attr("T: {int32, int64, float32, float64}")
    .Output("bins: T")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      // The input `size` must be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));

      const Tensor* size_tensor = c->input_tensor(1);
      if (size_tensor == nullptr) {
        // Return unknown shape if size is not known.
        c->set_output(0, c->UnknownShapeOfRank(1));
        return Status::OK();
      }

      // Return `[size]` shape if size is known.
      int32_t size_val = size_tensor->scalar<int32>()();
      if (size_val < 0) {
        return errors::InvalidArgument("size (", size_val,
                                       ") must be non-negative");
      }
      c->set_output(0, c->MakeShape({size_val}));
      return Status::OK();
    });
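// For reference (illustrative only): Bincount counts occurrences of each
// value in `arr`, e.g. arr = [1, 1, 2, 3] with size = 4 yields
// bins = [0, 2, 1, 1]; when `weights` is non-empty, each occurrence
// contributes its weight instead of 1.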

REGISTER_OP("DenseBincount")
    .Input("input: Tidx")
    .Input("size: Tidx")
    .Input("weights: T")
    .Attr("Tidx: {int32, int64}")
    .Attr("T: {int32, int64, float32, float64}")
    .Attr("binary_output: bool = false")
    .Output("output: T")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      // The input `input` must have rank at most 2 (a vector or a matrix).
      TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 2, &unused));
      // The input `size` must be a scalar.
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));

      const Tensor* size_tensor = c->input_tensor(1);
      if (size_tensor == nullptr) {
        // Return unknown shape if size is not known.
        c->set_output(0, c->UnknownShape());
        return Status::OK();
      }

      int64_t size_val;
      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
      if (dtype == DT_INT32) {
        size_val = static_cast<int64_t>(size_tensor->scalar<int32>()());
      } else if (dtype == DT_INT64) {
        size_val = size_tensor->scalar<int64>()();
      } else {
        return errors::InvalidArgument("size dtype must be int32 or int64");
      }
      // Return `[size]` shape if size is known.
      if (size_val < 0) {
        return errors::InvalidArgument("size (", size_val,
                                       ") must be non-negative");
      }
      if (c->Rank(c->input(0)) == 1) {
        c->set_output(0, c->MakeShape({size_val}));
      } else if (c->Rank(c->input(0)) == 2) {
        c->set_output(0, c->MakeShape({c->Dim(c->input(0), 0), size_val}));
      }
      return Status::OK();
    });
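// Shape behavior sketch (illustrative only): a rank-1 input of shape [n]
// produces output shape [size]; a rank-2 input of shape [batch, n] produces
// one histogram per row, i.e. output shape [batch, size].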

REGISTER_OP("SparseBincount")
    .Input("indices: int64")
    .Input("values: Tidx")
    .Input("dense_shape: int64")
    .Input("size: Tidx")
    .Input("weights: T")
    .Attr("Tidx: {int32, int64}")
    .Attr("T: {int32, int64, float32, float64}")
    .Attr("binary_output: bool = false")
    .Output("output: T")
    .SetShapeFn([](InferenceContext* c) {
      const Tensor* size_tensor = c->input_tensor(3);
      if (size_tensor == nullptr) {
        // Return unknown shape if size is not known.
        c->set_output(0, c->UnknownShape());
        return Status::OK();
      }

      int64_t size_val;
      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
      if (dtype == DT_INT32) {
        size_val = static_cast<int64_t>(size_tensor->scalar<int32>()());
      } else if (dtype == DT_INT64) {
        size_val = size_tensor->scalar<int64>()();
      } else {
        return errors::InvalidArgument("size dtype must be int32 or int64");
      }
      // Return `[size]` shape if size is known.
      if (size_val < 0) {
        return errors::InvalidArgument("size (", size_val,
                                       ") must be non-negative");
      }

      const Tensor* shape_tensor = c->input_tensor(2);
      if (shape_tensor == nullptr) {
        // Return unknown shape if dense_shape is not known.
        c->set_output(0, c->UnknownShape());
        return Status::OK();
      }
      if (shape_tensor->NumElements() == 1) {
        c->set_output(0, c->MakeShape({size_val}));
      } else if (shape_tensor->NumElements() == 2) {
        c->set_output(0,
                      c->MakeShape({shape_tensor->flat<int64>()(0), size_val}));
      } else {
        return errors::InvalidArgument("Input must be at most rank 2");
      }
      return Status::OK();
    });

REGISTER_OP("RaggedBincount")
    .Input("splits: int64")
    .Input("values: Tidx")
    .Input("size: Tidx")
    .Input("weights: T")
    .Attr("Tidx: {int32, int64}")
    .Attr("T: {int32, int64, float32, float64}")
    .Attr("binary_output: bool = false")
    .Output("output: T")
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->UnknownShape());
      return Status::OK();
    });
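// Informal note (not guaranteed by the registration above): at runtime the
// kernel produces a rank-2 result with one row per ragged row (the row count
// is splits.size() - 1); the shape function conservatively reports an
// unknown shape.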

REGISTER_OP("Cumsum")
    .Input("x: T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Cumprod")
    .Input("x: T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("CumulativeLogsumexp")
    .Input("x: T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: {float16, float32, float64}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);
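// Semantics reminder for the three cumulative ops above (illustrative only),
// along the `axis` dimension:
//   cumsum([a, b, c])                 -> [a, a + b, a + b + c]
//   cumsum([a, b, c], exclusive=true) -> [0, a, a + b]
//   cumsum([a, b, c], reverse=true)   -> [a + b + c, b + c, c]
// Cumprod is analogous with products, and CumulativeLogsumexp computes
// log(cumsum(exp(x))) in a more numerically stable way. All three leave the
// input shape unchanged, hence shape_inference::UnchangedShape.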

REGISTER_OP("QuantizedMatMul")
    .Input("a: T1")
    .Input("b: T2")
    .Input("min_a: float")
    .Input("max_a: float")
    .Input("min_b: float")
    .Input("max_b: float")
    .Output("out: Toutput")
    .Output("min_out: float")
    .Output("max_out: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Toutput: quantizedtype = DT_QINT32")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("Tactivation: quantizedtype = DT_QUINT8")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));

      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

// Note: This op is not commutative w.r.t. all its inputs.
REGISTER_OP("QuantizedMul")
    .Input("x: T1")
    .Input("y: T2")
    .Input("min_x: float")
    .Input("max_x: float")
    .Input("min_y: float")
    .Input("max_y: float")
    .Output("z: Toutput")
    .Output("min_z: float")
    .Output("max_z: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Toutput: quantizedtype = DT_QINT32")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

// Note: This op is not commutative w.r.t. all its inputs.
REGISTER_OP("QuantizedAdd")
    .Input("x: T1")
    .Input("y: T2")
    .Input("min_x: float")
    .Input("max_x: float")
    .Input("min_y: float")
    .Input("max_y: float")
    .Output("z: Toutput")
    .Output("min_z: float")
    .Output("max_z: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Toutput: quantizedtype = DT_QINT32")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
      // min_x, max_x, min_y, max_y should be scalars.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));

      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });
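// The quantized ops above repeat the same bookkeeping: verify that a run of
// min/max inputs are scalars and publish scalar range outputs. A minimal
// sketch of a shared helper (hypothetical; not used by any registration in
// this file) could look like this, using only InferenceContext calls already
// exercised above:
//
//   namespace {
//   // Checks inputs [first_input, last_input] are scalars, then sets
//   // outputs 1 and 2 (the min/max range outputs) to scalar shapes.
//   Status ScalarRangeShapeFn(InferenceContext* c, int first_input,
//                             int last_input) {
//     ShapeHandle unused;
//     for (int i = first_input; i <= last_input; ++i) {
//       TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
//     }
//     c->set_output(1, c->Scalar());
//     c->set_output(2, c->Scalar());
//     return Status::OK();
//   }
//   }  // namespace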

REGISTER_OP("QuantizeDownAndShrinkRange")
    .Input("input: Tinput")
    .Input("input_min: float")
    .Input("input_max: float")
    .Output("output: out_type")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("Tinput: quantizedtype")
    .Attr("out_type: quantizedtype")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("Requantize")
    .Input("input: Tinput")
    .Input("input_min: float")
    .Input("input_max: float")
    .Input("requested_output_min: float")
    .Input("requested_output_max: float")
    .Output("output: out_type")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("Tinput: quantizedtype")
    .Attr("out_type: quantizedtype")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("RequantizationRange")
    .Input("input: Tinput")
    .Input("input_min: float")
    .Input("input_max: float")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("Tinput: quantizedtype")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return Status::OK();
    });

// --------------------------------------------------------------------------

REGISTER_OP("Bucketize")
    .Input("input: T")
    .Output("output: int32")
    .Attr("T: {int32, int64, float, double}")
    .Attr("boundaries: list(float)")
    .SetShapeFn(shape_inference::UnchangedShape);
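// A worked example (illustrative only): with boundaries = [0, 10, 100],
// the buckets are (-inf, 0), [0, 10), [10, 100), [100, +inf), so
//   input  = [[-5, 10000], [150, 10], [5, 100]]
//   output = [[ 0,     3], [  3,  2], [1,   3]].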

REGISTER_OP("ClipByValue")
    .Input("t: T")
    .Input("clip_value_min: T")
    .Input("clip_value_max: T")
    .Output("output: T")
    .Attr("T: numbertype")
    .SetShapeFn(shape_inference::UnchangedShape);
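// For reference (illustrative only): each element of `t` is clamped into
// [clip_value_min, clip_value_max], e.g. t = [-1, 0, 5] with min = 0 and
// max = 3 yields [0, 0, 3].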

#ifdef INTEL_MKL
// Note: This op is not commutative w.r.t. all its inputs.
REGISTER_OP("_MklAddN")
    .Input("inputs: N * T")
    .Input("mkl_input: N * uint8")
    .Output("sum: T")
    .Output("mkl_sum: uint8")
    .Attr("N: int >= 1")
    .Attr("T: numbertype")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle cur = c->input(c->num_inputs() - 1);
      for (int i = c->num_inputs() - 2; i >= 0; --i) {
        TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
                                        "From merging shape ", i,
                                        " with other shapes.");
      }
      c->set_output(0, cur);
      return Status::OK();
    })
    .Doc(R"doc(
Add N input tensors element-wise using the MKL sum kernel.
inputs: Must all be the same size and shape.
)doc");

#endif  // INTEL_MKL

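// The two per-channel variants below differ from their per-tensor
// counterparts above in that `input_min` and `input_max` are rank-1 tensors
// (one value per channel) rather than scalars.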
REGISTER_OP("RequantizePerChannel")
    .Input("input: T")
    .Input("input_min: float")
    .Input("input_max: float")
    .Input("requested_output_min: float")
    .Input("requested_output_max: float")
    .Output("output: out_type")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("T: quantizedtype = DT_QINT32")
    .Attr("out_type: quantizedtype = DT_QUINT8")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("RequantizationRangePerChannel")
    .Input("input: T")
    .Input("input_min: float")
    .Input("input_max: float")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("T: quantizedtype = DT_QINT32")
    .Attr("clip_value_max: float")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("NextAfter")
    .Attr("T: {float64, float32} = DT_FLOAT")
    .Input("x1: T")
    .Input("x2: T")
    .Output("output: T")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
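// For reference (illustrative only): like std::nextafter, this returns the
// next representable floating-point value after x1 in the direction of x2,
// with the two inputs broadcast element-wise.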

REGISTER_OP("SobolSample")
    .Input("dim: int32")
    .Input("num_results: int32")
    .Input("skip: int32")
    .Attr("dtype: {float, double} = DT_FLOAT")
    .Output("samples: dtype")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      ShapeHandle unused;

      // inputs must be scalars
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));

      const Tensor* dim_t = c->input_tensor(0);
      const Tensor* num_results_t = c->input_tensor(1);

      int32_t dim = dim_t == nullptr ? InferenceContext::kUnknownDim
                                     : dim_t->scalar<int32>()();

      int32_t num_results = num_results_t == nullptr
                                ? InferenceContext::kUnknownDim
                                : num_results_t->scalar<int32>()();

      c->set_output(0, c->Matrix(num_results, dim));
      return Status::OK();
    });
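// Shape behavior sketch (illustrative only): the output is a
// [num_results, dim] matrix of quasi-random sample points; when either
// scalar input is not statically known, the corresponding output dimension
// is reported as unknown.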

}  // namespace tensorflow