1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/numeric_op.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20
21 namespace tensorflow {
22
23 using shape_inference::DimensionHandle;
24 using shape_inference::InferenceContext;
25 using shape_inference::ShapeHandle;
26
27 REGISTER_OP("AddN")
28 .Input("inputs: N * T")
29 .Output("sum: T")
30 .Attr("N: int >= 1")
31 .Attr("T: {numbertype, variant}")
32 .SetIsCommutative()
33 .SetIsAggregate()
__anonb22bfa860102(InferenceContext* c) 34 .SetShapeFn([](InferenceContext* c) {
35 ShapeHandle cur = c->input(c->num_inputs() - 1);
36 for (int i = c->num_inputs() - 2; i >= 0; --i) {
37 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
38 "From merging shape ", i,
39 " with other shapes.");
40 }
41 c->set_output(0, cur);
42
43 DataType dtype;
44 TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
45
46 if (dtype != DT_VARIANT) {
47 // Exit early if not DT_VARIANT.
48 return Status::OK();
49 } else {
50 // DT_VARIANT shape handle shape inference. All sizes and dtypes must
51 // be the same; all shapes must be compatible via Merge.
52 std::vector<shape_inference::ShapeAndType> cur_shapes_and_types;
53 auto* shapes_and_types =
54 c->input_handle_shapes_and_types(c->num_inputs() - 1);
55 if (shapes_and_types) {
56 cur_shapes_and_types = *shapes_and_types;
57 }
58
59 for (int i = c->num_inputs() - 2; i >= 0; --i) {
60 auto shapes_and_types_i = c->input_handle_shapes_and_types(i);
61 if (!shapes_and_types && shapes_and_types_i) {
62 // TODO(ebrevdo): Find cases where this happens and fix their shape
63 // inference. If we are calling AddN on variant types, they should
64 // all have consistent shape_and_type info.
65 shapes_and_types = shapes_and_types_i;
66 } else if (shapes_and_types && shapes_and_types_i) {
67 if (shapes_and_types_i->size() != shapes_and_types->size()) {
68 return errors::InvalidArgument(
69 "shapes_and_types[", i,
70 "].size() == ", shapes_and_types_i->size(),
71 " != shapes_and_types[0].size() == ",
72 shapes_and_types->size());
73 }
74 for (int j = 0; j < shapes_and_types->size(); ++j) {
75 if (shapes_and_types->at(j).dtype !=
76 shapes_and_types_i->at(j).dtype) {
77 return errors::InvalidArgument(
78 "shapes_and_types[", i, "][", j, "].dtype() == ",
79 DataTypeString(shapes_and_types_i->at(j).dtype),
80 " != shapes_and_types[0][", j, "].dtype == ",
81 DataTypeString(shapes_and_types->at(j).dtype));
82 }
83 TF_RETURN_WITH_CONTEXT_IF_ERROR(
84 c->Merge(shapes_and_types_i->at(j).shape,
85 cur_shapes_and_types.at(j).shape,
86 &cur_shapes_and_types.at(j).shape),
87 "From merging shapes_and_types[", i, "][", j, "].shape with ",
88 "shapes_and_types[0][", j, "].shape");
89 }
90 }
91 }
92 if (shapes_and_types) {
93 c->set_output_handle_shapes_and_types(0, cur_shapes_and_types);
94 }
95 return Status::OK();
96 }
97 });
98
99 // --------------------------------------------------------------------------
100
// Note that the following operator is just a placeholder and has no
// associated kernel. The code in accumulate_n_optimizer.cc replaces
// this placeholder with a graph of operators that do have kernels.
// The Python code that generates instances of this op is currently in
// contrib/framework/python/ops/accumulate_n_v2.py
//
// Unlike AddN, the output shape is not merged from the inputs. It is taken
// from the `shape` attr via shape_inference::ExplicitShape.
REGISTER_OP("AccumulateNV2")
    .Input("inputs: N * T")
    .Output("sum: T")
    .Attr("N: int >= 1")
    .Attr("T: numbertype")
    .Attr("shape: shape")
    .SetIsCommutative()
    .SetIsAggregate()
    .SetShapeFn(shape_inference::ExplicitShape);
115
116 // --------------------------------------------------------------------------
117
118 REGISTER_OP("BatchMatMul")
119 .Input("x: T")
120 .Input("y: T")
121 .Output("output: T")
122 .Attr(
123 "T: {bfloat16, half, float, double, int32, int64, complex64, "
124 "complex128}")
125 .Attr("adj_x: bool = false")
126 .Attr("adj_y: bool = false")
__anonb22bfa860202(InferenceContext* c) 127 .SetShapeFn([](InferenceContext* c) {
128 ShapeHandle a_shape;
129 ShapeHandle b_shape;
130 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
131 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));
132
133 // Determine output rows and cols.
134 bool adj_x;
135 bool adj_y;
136 TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
137 TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
138 DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
139 DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);
140
141 // Batch dims match between inputs.
142 ShapeHandle a_batch_dims;
143 ShapeHandle b_batch_dims;
144 ShapeHandle batch_dims;
145 TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
146 TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
147 TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
148
149 // Assert inner dims match.
150 DimensionHandle unused;
151 TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
152 c->Dim(b_shape, adj_y ? -1 : -2), &unused));
153
154 ShapeHandle out;
155 TF_RETURN_IF_ERROR(c->Concatenate(
156 batch_dims, c->Matrix(output_rows, output_cols), &out));
157 c->set_output(0, out);
158 return Status::OK();
159 });
160
161 // --------------------------------------------------------------------------
// Casting Ops
//
// NOTE: Only a limited set of types is supported by
// Cast. The exact casting rule is TBD. The current
// implementation uses C++ static cast rules for numeric
// types, which may be changed in the future.
//
// At registration time both ops accept any SrcT/DstT and keep the input
// shape unchanged; which type pairs actually work is decided by the
// registered kernels (not visible here).
REGISTER_OP("Cast")
    .Input("x: SrcT")
    .Output("y: DstT")
    .Attr("SrcT: type")
    .Attr("DstT: type")
    .Attr("Truncate: bool = false")
    .SetShapeFn(shape_inference::UnchangedShape);

// Host-memory variant of Cast (see the op's Doc string below).
REGISTER_OP("_HostCast")
    .Input("x: SrcT")
    .Output("y: DstT")
    .Attr("SrcT: type")
    .Attr("DstT: type")
    .Attr("Truncate: bool = false")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Cast x of type SrcT to y of DstT.

_HostCast requires its input and produces its output in host memory.
)doc");
188
189 // --------------------------------------------------------------------------
190
// Real-valued Abs: output dtype equals input dtype, shape unchanged.
REGISTER_OP("Abs")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double, int32, int64}")
    .SetShapeFn(shape_inference::UnchangedShape);

// Complex inputs produce a real output dtype (Tout), with the same shape.
REGISTER_OP("ComplexAbs")
    .Input("x: T")
    .Output("y: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);
203
// Declares cwise unary operations signature: 't -> 't
// Each macro below is spliced after `REGISTER_OP(name).` and expands to an
// op-builder chain ending in UnchangedShape; they differ only in the allowed
// dtypes (and, for the gradient variant, in taking two inputs).
//
// UNARY: real floats, int32/int64 and complex types.
#define UNARY()                                                          \
  Input("x: T")                                                          \
      .Output("y: T")                                                    \
      .Attr(                                                             \
          "T: {bfloat16, half, float, double, int32, int64, complex64, " \
          "complex128}")                                                 \
      .SetShapeFn(shape_inference::UnchangedShape)

// UNARY_REAL: floating-point types only (no integers, no complex).
#define UNARY_REAL()                              \
  Input("x: T")                                   \
      .Output("y: T")                             \
      .Attr("T: {bfloat16, half, float, double}") \
      .SetShapeFn(shape_inference::UnchangedShape)

// UNARY_COMPLEX: floating-point and complex types (no integers).
#define UNARY_COMPLEX()                                                  \
  Input("x: T")                                                          \
      .Output("y: T")                                                    \
      .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
      .SetShapeFn(shape_inference::UnchangedShape)

// UNARY_GRADIENT_COMPLEX: (y, dy) -> z; the output shape is taken from the
// first input `y` (UnchangedShape).
#define UNARY_GRADIENT_COMPLEX()                                         \
  Input("y: T")                                                          \
      .Input("dy: T")                                                    \
      .Output("z: T")                                                    \
      .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
      .SetShapeFn(shape_inference::UnchangedShape)
231
// Element-wise unary ops. The macro suffix selects the dtype set; see the
// macro definitions above for the exact signatures.
REGISTER_OP("Neg").UNARY();

REGISTER_OP("Inv").UNARY();

REGISTER_OP("InvGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Reciprocal").UNARY();

REGISTER_OP("ReciprocalGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Square").UNARY();

REGISTER_OP("Sqrt").UNARY_COMPLEX();

REGISTER_OP("SqrtGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Rsqrt").UNARY_COMPLEX();

REGISTER_OP("Round").UNARY();

REGISTER_OP("RsqrtGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Exp").UNARY_COMPLEX();

REGISTER_OP("Expm1").UNARY_COMPLEX();

REGISTER_OP("Log").UNARY_COMPLEX();

REGISTER_OP("Log1p").UNARY_COMPLEX();

REGISTER_OP("Sinh").UNARY_COMPLEX();

REGISTER_OP("Cosh").UNARY_COMPLEX();

REGISTER_OP("Tanh").UNARY_COMPLEX();

REGISTER_OP("Asinh").UNARY_COMPLEX();

REGISTER_OP("Acosh").UNARY_COMPLEX();

REGISTER_OP("Atanh").UNARY_COMPLEX();

REGISTER_OP("TanhGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Lgamma").UNARY_REAL();

REGISTER_OP("Digamma").UNARY_REAL();

REGISTER_OP("Erf").UNARY_REAL();

REGISTER_OP("Erfc").UNARY_REAL();

REGISTER_OP("Sigmoid").UNARY_COMPLEX();

REGISTER_OP("SigmoidGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Sin").UNARY_COMPLEX();

REGISTER_OP("Cos").UNARY_COMPLEX();

REGISTER_OP("Tan").UNARY();

REGISTER_OP("Asin").UNARY();

REGISTER_OP("Acos").UNARY();

REGISTER_OP("Atan").UNARY();

REGISTER_OP("BesselI0e").UNARY_REAL();

REGISTER_OP("BesselI1e").UNARY_REAL();

// Internal op produced by a graph rewrite (per its Doc string); `op_names`
// lists the unary ops it stands for. Shape is unchanged.
REGISTER_OP("_UnaryOpsComposition")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {float, half, double}")
    .Attr("op_names: list(string)")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to create these operators.
)doc");
314
// Drop the unary signature helpers so they do not leak into the rest of the
// translation unit (matching the #undefs for the binary/comparison macros).
// UNARY_GRADIENT_COMPLEX was previously left defined.
#undef UNARY
#undef UNARY_REAL
#undef UNARY_COMPLEX
#undef UNARY_GRADIENT_COMPLEX
318
// Floating-point classification predicates: same shape as `x`, bool output.
REGISTER_OP("IsNan")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("IsInf")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("IsFinite")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

// Sign accepts the same dtype set as the UNARY() macro.
REGISTER_OP("Sign")
    .Input("x: T")
    .Output("y: T")
    .Attr(
        "T: {bfloat16, half, float, double, int32, int64, complex64, "
        "complex128}")
    .SetShapeFn(shape_inference::UnchangedShape);

// Rounding-style ops: floating-point dtypes only, shape unchanged.
REGISTER_OP("Floor")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Ceil")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Rint")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);
362
// Declares cwise binary operations signature: 't, 't -> 't.
// Neither macro sets a shape function; each use site adds its own
// (BroadcastBinaryOpShapeFn throughout this file).

// BINARY_MORE additionally allows uint8, int8, uint16 and int16.
#define BINARY_MORE()                                                          \
  Input("x: T").Input("y: T").Output("z: T").Attr(                             \
      "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, int32, " \
      "int64, complex64, complex128}")

// BINARY_FEWER: same as BINARY_MORE without the small integer types.
#define BINARY_FEWER()                                               \
  Input("x: T").Input("y: T").Output("z: T").Attr(                   \
      "T: {bfloat16, half, float, double, int32, int64, complex64, " \
      "complex128}")
374
// Add also accepts `string` (unlike AddV2). Shapes broadcast.
REGISTER_OP("Add")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
        "complex64, complex128, string}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// TODO(rmlarsen): Add a Python wrapper that switches non-string instances to
// use AddV2 (b/68646025).
// AddV2 drops `string` and is marked aggregate and commutative, which Add is
// not.
REGISTER_OP("AddV2")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
        "complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .SetIsAggregate()
    .SetIsCommutative();

// MKL variant of Add: carries an extra uint8 input per data input and an
// extra uint8 output (presumably MKL layout metadata; not visible here).
REGISTER_OP("_MklAdd")
    .Input("x: T")
    .Input("y: T")
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("z: T")
    .Output("mkl_z: uint8")
    .Attr(
        "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
        "complex128, string}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns `x` + `y` element-wise.

*NOTE*: `tf.math.add` supports broadcasting. `tf.math.add_n` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
)doc");
414
// Broadcasting element-wise binary arithmetic. The _Mkl* variants add a uint8
// companion input per data input and a uint8 companion output.
REGISTER_OP("Sub").BINARY_MORE().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("_MklSub")
    .BINARY_FEWER()
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("mkl_z: uint8")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns x - y element-wise.

*NOTE*: `Sub` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

REGISTER_OP("Mul").BINARY_MORE().SetIsCommutative().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("MulNoNan")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("_MklMul")
    .BINARY_MORE()
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("mkl_z: uint8")
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns x * y element-wise.

*NOTE*: `Mul` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("DivNoNan")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("FloorDiv")
    .BINARY_MORE()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("TruncateDiv")
    .BINARY_MORE()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("RealDiv").BINARY_MORE().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("SquaredDifference")
    .BINARY_FEWER()
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("_MklSquaredDifference")
    .BINARY_FEWER()
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("mkl_z: uint8")
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns (x - y)(x - y) element-wise.

*NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

REGISTER_OP("Xlogy")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Xdivy")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// The binary signature helpers are not used past this point.
#undef BINARY_FEWER
#undef BINARY_MORE
512
// Broadcasting element-wise binary ops with explicit dtype lists.
REGISTER_OP("Maximum")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double, int32, int64}")
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// MKL variant of Maximum with uint8 companion inputs/output.
REGISTER_OP("_MklMaximum")
    .Input("x: T")
    .Input("y: T")
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("z: T")
    .Output("mkl_z: uint8")
    .Attr("T: {half, float, double, int32, int64}")
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns the max of x and y (i.e. x > y ? x : y) element-wise.

*NOTE*: `Maximum` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

REGISTER_OP("Minimum")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double, int32, int64}")
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// NOTE(review): `float16` and `half` name the same dtype, so it appears twice
// in this list. Left as-is because the attr string is part of the registered
// OpDef.
REGISTER_OP("Mod")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {int32, int64, float16, half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("FloorMod")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {int32, int64, bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("TruncateMod")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {int32, int64, bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Pow")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, float, half, double, int32, int64, complex64, "
        "complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Special functions: float/double only, broadcasting shapes.
REGISTER_OP("Igammac")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Igamma")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("IgammaGradA")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Zeta")
    .Input("x: T")
    .Input("q: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Polygamma")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Note the input order: `y` first, then `x`.
REGISTER_OP("Atan2")
    .Input("y: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
617
618 REGISTER_OP("Betainc")
619 .Input("a: T")
620 .Input("b: T")
621 .Input("x: T")
622 .Output("z: T")
623 .Attr("T: {float, double}")
__anonb22bfa860302(InferenceContext* c) 624 .SetShapeFn([](InferenceContext* c) {
625 const int num_inputs = 3;
626 ShapeHandle output = c->UnknownShape();
627 int num_scalars = 0;
628 ShapeHandle some_non_scalar;
629 for (int i = 0; i < num_inputs; ++i) {
630 ShapeHandle in = c->input(i);
631 if (!c->RankKnown(in)) {
632 some_non_scalar = in;
633 // An input with unknown rank could be either a scalar (to be
634 // broadcast) or some other shape.
635 } else if (c->Rank(in) == 0) {
636 // Input is a scalar, it will be broadcast to the output shape.
637 ++num_scalars;
638 } else {
639 TF_RETURN_IF_ERROR(c->Merge(output, in, &output));
640 some_non_scalar = output;
641 }
642 }
643
644 if (num_scalars == num_inputs - 1) {
645 // If all but one input is known to be a scalar, then output is the
646 // remaining input.
647 output = some_non_scalar;
648 } else if (num_scalars == num_inputs) {
649 // If all are scalars, output is scalar; pick the first one arbitrarily.
650 output = c->input(0);
651 }
652
653 c->set_output(0, output);
654 return Status::OK();
655 });
656
657 // --------------------------------------------------------------------------
658
// Declares cwise binary comparison operations signature: 't, 't -> bool,
// where 't has a natural total order.
// Inputs broadcast (BroadcastBinaryOpShapeFn); dtypes are restricted to
// realnumbertype.
#define COMPARISON()             \
  Input("x: T")                  \
      .Input("y: T")             \
      .Output("z: bool")         \
      .Attr("T: realnumbertype") \
      .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)

REGISTER_OP("Less").COMPARISON();

REGISTER_OP("LessEqual").COMPARISON();

REGISTER_OP("Greater").COMPARISON();

REGISTER_OP("GreaterEqual").COMPARISON();

#undef COMPARISON
677
678 // --------------------------------------------------------------------------
679
// Declares cwise (in)equality signature: 't, 't -> bool. Unlike COMPARISON it
// is commutative and also admits quantized, string, bool and complex dtypes.
#define EQUALITY_COMPARISON()                                              \
  Input("x: T")                                                            \
      .Input("y: T")                                                       \
      .Output("z: bool")                                                   \
      .SetIsCommutative()                                                  \
      .Attr(                                                               \
          "T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
          "int64, complex64, quint8, qint8, qint32, string, bool, "        \
          "complex128}")                                                   \
      .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)

REGISTER_OP("Equal").EQUALITY_COMPARISON();

REGISTER_OP("NotEqual").EQUALITY_COMPARISON();

#undef EQUALITY_COMPARISON
696
697 REGISTER_OP("ApproximateEqual")
698 .Input("x: T")
699 .Input("y: T")
700 .Output("z: bool")
701 .SetIsCommutative()
702 .Attr("T: numbertype")
703 .Attr("tolerance: float = 0.00001")
__anonb22bfa860402(InferenceContext* c) 704 .SetShapeFn([](InferenceContext* c) {
705 // The inputs 'x' and 'y' must have the same shape.
706 ShapeHandle data_x = c->input(0);
707 ShapeHandle data_y = c->input(1);
708 TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x));
709 return shape_inference::UnchangedShape(c);
710 });
711
712 // --------------------------------------------------------------------------
713
// Boolean ops. LogicalNot keeps its input's shape.
REGISTER_OP("LogicalNot")
    .Input("x: bool")
    .Output("y: bool")
    .SetShapeFn(shape_inference::UnchangedShape);

// Binary boolean signature: bool, bool -> bool, commutative, broadcasting.
#define BINARY_LOGICAL()  \
  Input("x: bool")        \
      .Input("y: bool")   \
      .Output("z: bool")  \
      .SetIsCommutative() \
      .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)

REGISTER_OP("LogicalAnd").BINARY_LOGICAL();

REGISTER_OP("LogicalOr").BINARY_LOGICAL();

#undef BINARY_LOGICAL
731
732 // --------------------------------------------------------------------------
733
734 REGISTER_OP("Select")
735 .Input("condition: bool")
736 .Input("t: T")
737 .Input("e: T")
738 .Output("output: T")
739 .Attr("T: type")
__anonb22bfa860502(InferenceContext* c) 740 .SetShapeFn([](InferenceContext* c) {
741 auto* handle_data_1 = c->input_handle_shapes_and_types(1);
742 auto* handle_data_2 = c->input_handle_shapes_and_types(2);
743 // Merge handle shape and dtype if applicable.
744 if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
745 const auto size = handle_data_1->size();
746 std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
747 if (size != handle_data_2->size()) {
748 return errors::InvalidArgument(
749 "Trying to merge handles pointing to different numbers of "
750 "tensors.");
751 }
752
753 for (int i = 0; i < size; ++i) {
754 const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
755 const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
756 if (s1.dtype != s2.dtype) {
757 // TODO(apassos) resolve this in the manner of b/32476923
758 return errors::InvalidArgument(
759 "Trying to merge handles pointing to different dtypes.");
760 }
761 merged_handle_data[i].dtype = s1.dtype;
762 TF_RETURN_IF_ERROR(
763 c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
764 }
765
766 c->set_output_handle_shapes_and_types(0, merged_handle_data);
767 }
768
769 // The inputs 'then' and 'else' must have the same shape.
770 ShapeHandle data = c->input(1);
771 ShapeHandle other = c->input(2);
772 TF_RETURN_IF_ERROR(c->Merge(data, other, &data));
773
774 // The input 'cond' must either have the same shape as 'then' and
775 // 'else', or be a vector if 'then' and 'else' are at least vectors.
776 ShapeHandle cond = c->input(0);
777
778 if (!c->RankKnown(cond) || !c->RankKnown(data)) {
779 c->set_output(0, data);
780 return Status::OK();
781 }
782
783 // rank of shape and data is known.
784
785 const int32 cond_rank = c->Rank(cond);
786 const int32 data_rank = c->Rank(data);
787
788 if (cond_rank == 0) {
789 // The rank of 'cond' is a scalar.
790 // t and e can have any shape.
791 c->set_output(0, data);
792 return Status::OK();
793 }
794
795 if (cond_rank != 1) {
796 // If 'cond' is not a vector, and not a scalar,
797 // then shape must match 'then' and 'else'
798 TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
799 c->set_output(0, data);
800 return Status::OK();
801 }
802
803 if (data_rank == 0) {
804 // if 'then' and 'else' are scalar also the cond must be
805 TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
806 c->set_output(0, data);
807 return Status::OK();
808 }
809
810 if (cond_rank == 1) {
811 // if the cond is a vector and the 'then' is not a scalar,
812 // the first dimension of 'then' and 'else'
813 TF_RETURN_IF_ERROR(c->Merge(cond, c->Vector(c->Dim(data, 0)), &cond));
814 c->set_output(0, data);
815 return Status::OK();
816 }
817
818 c->set_output(0, data);
819
820 return Status::OK();
821 });
822
823 // --------------------------------------------------------------------------
824
// 2-D matrix products; both use the shared shape_inference::MatMulShape
// (presumably honoring transpose_a/transpose_b - see common_shape_fns).
REGISTER_OP("MatMul")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr(
        "T: {bfloat16, half, float, double, int32, int64, complex64, "
        "complex128}")
    .SetShapeFn(shape_inference::MatMulShape);

// Mixed float/bfloat16 inputs with a float output. The *_is_sparse attrs are
// hints for the kernel; they do not affect shape inference.
REGISTER_OP("SparseMatMul")
    .Input("a: Ta")
    .Input("b: Tb")
    .Output("product: float")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("a_is_sparse: bool = false")
    .Attr("b_is_sparse: bool = false")
    .Attr("Ta: {float, bfloat16} = DT_FLOAT")
    .Attr("Tb: {float, bfloat16} = DT_FLOAT")
    .SetShapeFn(shape_inference::MatMulShape);
847
848 // --------------------------------------------------------------------------
849
// For operations where the output is a reduction function along some
// dimensions of the input.
// All of these share one signature: the axes to reduce come from the
// `reduction_indices` input (int32 or int64), `keep_dims` controls whether
// reduced axes are retained, and shapes are computed by
// shape_inference::ReductionShape.
REGISTER_OP("Sum")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("EuclideanNorm")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Mean")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Prod")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Min")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Max")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);
905
906 namespace {
907
ArgOpShape(shape_inference::InferenceContext * c)908 Status ArgOpShape(shape_inference::InferenceContext* c) {
909 ShapeHandle dimension_shape;
910 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &dimension_shape));
911
912 ShapeHandle input_shape = c->input(0);
913 if (!c->RankKnown(input_shape)) {
914 return shape_inference::UnknownShape(c);
915 }
916
917 const int32 input_rank = c->Rank(input_shape);
918 if (input_rank <= 1) {
919 // Reducing a scalar/vector must return a scalar.
920 return shape_inference::ScalarShape(c);
921 }
922
923 const Tensor* dim_t = c->input_tensor(1);
924 if (dim_t == nullptr) {
925 // We don't know the value of the dimension, but we
926 // know the rank of the input, so return the correct
927 // rank with unknown dimensions.
928 std::vector<DimensionHandle> dims(input_rank - 1);
929 for (int i = 0; i < dims.size(); ++i) {
930 dims[i] = c->UnknownDim();
931 }
932
933 c->set_output(0, c->MakeShape(dims));
934 return Status::OK();
935 }
936
937 int64 dimension_val;
938 if (dim_t->dtype() == DT_INT32) {
939 dimension_val = dim_t->scalar<int32>()();
940 } else {
941 dimension_val = dim_t->scalar<int64>()();
942 }
943
944 int64 axis = dimension_val < 0 ? dimension_val + input_rank : dimension_val;
945 if (axis < 0 || axis >= input_rank) {
946 return errors::InvalidArgument(
947 "Dimension (", dimension_val, ") must be in the range [", -input_rank,
948 ", ", input_rank, "), where ", input_rank,
949 " is the number of dimensions in the input.");
950 }
951
952 // Return the input shape without the dimension being reduced.
953 std::vector<DimensionHandle> dims;
954 for (int i = 0; i < input_rank; ++i) {
955 if (axis != i) {
956 dims.emplace_back(c->Dim(input_shape, i));
957 }
958 }
959 c->set_output(0, c->MakeShape(dims));
960 return Status::OK();
961 }
962
963 } // namespace
964
// Index-returning reductions over a single axis given by `dimension`. The
// output index dtype is selected by `output_type` (default int64). Both use
// ArgOpShape.
REGISTER_OP("ArgMax")
    .Input("input: T")
    .Input("dimension: Tidx")
    .Output("output: output_type")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("output_type: {int32, int64} = DT_INT64")
    .SetShapeFn(ArgOpShape);

REGISTER_OP("ArgMin")
    .Input("input: T")
    .Input("dimension: Tidx")
    .Output("output: output_type")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("output_type: {int32, int64} = DT_INT64")
    .SetShapeFn(ArgOpShape);
982
983 namespace {
984
SegmentReductionShapeFn(InferenceContext * c)985 Status SegmentReductionShapeFn(InferenceContext* c) {
986 ShapeHandle data_shape;
987 ShapeHandle segment_ids_shape;
988 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
989 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &segment_ids_shape));
990
991 ShapeHandle subshape;
992 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
993
994 ShapeHandle out;
995 TF_RETURN_IF_ERROR(
996 c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
997 c->set_output(0, out);
998 return Status::OK();
999 }
1000
SparseSegmentReductionShapeFn(InferenceContext * c)1001 Status SparseSegmentReductionShapeFn(InferenceContext* c) {
1002 ShapeHandle data_shape;
1003 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1004
1005 ShapeHandle indices_shape;
1006 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
1007
1008 ShapeHandle segment_ids_shape;
1009 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));
1010
1011 // indices and segment_ids should merge cleanly.
1012 ShapeHandle unused;
1013 TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));
1014
1015 ShapeHandle subshape;
1016 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1017
1018 ShapeHandle out;
1019 TF_RETURN_IF_ERROR(
1020 c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
1021 c->set_output(0, out);
1022 return Status::OK();
1023 }
1024
SparseSegmentReductionGradShapeFn(InferenceContext * c)1025 Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
1026 ShapeHandle data_shape;
1027 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1028
1029 ShapeHandle indices_shape;
1030 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
1031
1032 // indices and segment_ids should merge cleanly.
1033 ShapeHandle unused;
1034 TF_RETURN_IF_ERROR(c->Merge(c->input(2), indices_shape, &unused));
1035
1036 // output_dim0 should be a scalar
1037 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1038
1039 ShapeHandle subshape;
1040 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1041
1042 const Tensor* dim0 = c->input_tensor(3);
1043 ShapeHandle dim0_shape;
1044 if (dim0 == nullptr) {
1045 // We don't have the value at inference time, so the output
1046 // shape is unknown.
1047 dim0_shape = c->Vector(InferenceContext::kUnknownDim);
1048 } else {
1049 auto dim0_value = dim0->scalar<int32>()();
1050 if (dim0_value < 0) {
1051 return errors::InvalidArgument(
1052 "Cannot specify a negative value for output_dim0");
1053 }
1054 dim0_shape = c->Vector(dim0_value);
1055 }
1056
1057 ShapeHandle out;
1058 TF_RETURN_IF_ERROR(c->Concatenate(dim0_shape, subshape, &out));
1059 c->set_output(0, out);
1060 return Status::OK();
1061 }
1062
SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext * c)1063 Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) {
1064 ShapeHandle data_shape;
1065 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1066
1067 ShapeHandle indices_shape;
1068 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
1069
1070 ShapeHandle segment_ids_shape;
1071 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));
1072
1073 ShapeHandle num_segments_shape;
1074 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &num_segments_shape));
1075
1076 // indices and segment_ids should merge cleanly.
1077 ShapeHandle unused;
1078 TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));
1079
1080 ShapeHandle subshape;
1081 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1082
1083 ShapeHandle out;
1084 const Tensor* dim0 = c->input_tensor(3);
1085 if (dim0 == nullptr) {
1086 // We don't have the value at inference time, so the output
1087 // shape is unknown.
1088 TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(InferenceContext::kUnknownDim),
1089 subshape, &out));
1090 } else {
1091 auto dim0_value = dim0->scalar<int32>()();
1092 if (dim0_value < 0) {
1093 return errors::InvalidArgument(
1094 "Cannot specify a negative value for num_segments");
1095 }
1096 TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(dim0_value), subshape, &out));
1097 }
1098 c->set_output(0, out);
1099 return Status::OK();
1100 }
1101
UnsortedSegmentReductionShapeFn(InferenceContext * c)1102 Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
1103 ShapeHandle s_data = c->input(0);
1104 ShapeHandle s_segment_ids = c->input(1);
1105 ShapeHandle s_num_segments = c->input(2);
1106 TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
1107
1108 ShapeHandle out;
1109
1110 // Leading dimensions of data must be compatible with dimensions of
1111 // <s_segment_ids>.
1112 if (c->RankKnown(s_segment_ids)) {
1113 TF_RETURN_IF_ERROR(
1114 c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
1115
1116 // Get the value of the num_segments input tensor.
1117 DimensionHandle num_segments_dim;
1118 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
1119
1120 // Output is {segment_id_rank} + s_data[segment_id_rank:].
1121 ShapeHandle s_data_suffix;
1122 TF_RETURN_IF_ERROR(
1123 c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
1124 TF_RETURN_IF_ERROR(
1125 c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
1126 } else {
1127 out = c->UnknownShape();
1128 }
1129 c->set_output(0, out);
1130 return Status::OK();
1131 }
1132 } // namespace
1133
// Sorted segment reductions over `data` keyed by `segment_ids`.
// All five share SegmentReductionShapeFn: output shape [?] + data.shape[1:].
// Sum/Mean/Prod accept any numbertype; Min/Max are restricted to
// realnumbertype.
REGISTER_OP("SegmentSum")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentMean")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentProd")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

// Min/Max need an ordered element type, hence realnumbertype.
REGISTER_OP("SegmentMin")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentMax")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);
1173
// Unsorted segment reductions. Each takes an explicit `num_segments` input
// (int32 or int64 via Tnumsegments, default int32). All share
// UnsortedSegmentReductionShapeFn: when segment_ids has known rank, the output
// is [num_segments] + data.shape[rank(segment_ids):]; otherwise it is unknown.
REGISTER_OP("UnsortedSegmentSum")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(UnsortedSegmentReductionShapeFn);

// Max/Min are restricted to realnumbertype.
REGISTER_OP("UnsortedSegmentMax")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(UnsortedSegmentReductionShapeFn);

REGISTER_OP("UnsortedSegmentMin")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(UnsortedSegmentReductionShapeFn);

REGISTER_OP("UnsortedSegmentProd")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(UnsortedSegmentReductionShapeFn);
1213
// Sparse segment reductions. Each op takes `data`, `indices` (Tidx) and
// int32 `segment_ids`. The shape functions differ by variant:
//  * plain ops: SparseSegmentReductionShapeFn ([?] + data.shape[1:]);
//  * *WithNumSegments: an extra scalar `num_segments` input, handled by
//    SparseSegmentReductionWithNumSegmentsShapeFn;
//  * *Grad: an extra int32 `output_dim0` input, handled by
//    SparseSegmentReductionGradShapeFn.
REGISTER_OP("SparseSegmentSum")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentSumWithNumSegments")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);

// Mean / SqrtN variants are restricted to float and double.
REGISTER_OP("SparseSegmentMean")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentMeanWithNumSegments")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);

REGISTER_OP("SparseSegmentMeanGrad")
    .Input("grad: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Input("output_dim0: int32")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionGradShapeFn);

REGISTER_OP("SparseSegmentSqrtN")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);

REGISTER_OP("SparseSegmentSqrtNGrad")
    .Input("grad: T")
    .Input("indices: Tidx")
    .Input("segment_ids: int32")
    .Input("output_dim0: int32")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionGradShapeFn);
1293
// Boolean reductions over `reduction_indices`. Both use
// shape_inference::ReductionShape, which is parameterized by the `keep_dims`
// attr declared here.
REGISTER_OP("All")
    .Input("input: bool")
    .Input("reduction_indices: Tidx")
    .Output("output: bool")
    .Attr("keep_dims: bool = false")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Any")
    .Input("input: bool")
    .Input("reduction_indices: Tidx")
    .Attr("keep_dims: bool = false")
    .Output("output: bool")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);
1309
1310 // --------------------------------------------------------------------------
1311
1312 namespace {
1313
1314 template <typename T>
RangeSize(const Tensor * start_t,const Tensor * limit_t,const Tensor * delta_t,InferenceContext * const c)1315 Status RangeSize(const Tensor* start_t, const Tensor* limit_t,
1316 const Tensor* delta_t, InferenceContext* const c) {
1317 T start = start_t->scalar<T>()();
1318 T limit = limit_t->scalar<T>()();
1319 T delta = delta_t->scalar<T>()();
1320 if (start > limit && delta > 0) {
1321 return errors::InvalidArgument(
1322 "Requires start <= limit when delta > 0: ", start, "/", limit);
1323 }
1324 if (start < limit && delta < 0) {
1325 return errors::InvalidArgument(
1326 "Requires start >= limit when delta < 0: ", start, "/", limit);
1327 }
1328 if (delta == 0) {
1329 return errors::InvalidArgument("Requires delta != 0");
1330 }
1331
1332 int64 size =
1333 (std::is_integral<T>::value
1334 ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
1335 : std::ceil(std::abs((limit - start) / delta)));
1336 c->set_output(0, c->Vector(size));
1337 return Status::OK();
1338 }
1339
1340 } // namespace
1341
1342 REGISTER_OP("Range")
1343 .Input("start: Tidx")
1344 .Input("limit: Tidx")
1345 .Input("delta: Tidx")
1346 .Output("output: Tidx")
1347 .Attr("Tidx: {bfloat16, float, double, int32, int64} = DT_INT32")
__anonb22bfa860902(InferenceContext* c) 1348 .SetShapeFn([](InferenceContext* c) {
1349 ShapeHandle unused;
1350 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
1351 " for 'start'");
1352 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
1353 " for 'limit'");
1354 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
1355 " for 'delta'");
1356 const Tensor* start_t = c->input_tensor(0);
1357 const Tensor* limit_t = c->input_tensor(1);
1358 const Tensor* delta_t = c->input_tensor(2);
1359 DataType dtype;
1360 TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
1361 if (start_t == nullptr || limit_t == nullptr || delta_t == nullptr) {
1362 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1363 return Status::OK();
1364 }
1365 if (dtype == DT_INT32) {
1366 return RangeSize<int32>(start_t, limit_t, delta_t, c);
1367 } else if (dtype == DT_INT64) {
1368 return RangeSize<int64>(start_t, limit_t, delta_t, c);
1369 } else if (dtype == DT_FLOAT) {
1370 return RangeSize<float>(start_t, limit_t, delta_t, c);
1371 } else {
1372 return RangeSize<double>(start_t, limit_t, delta_t, c);
1373 }
1374 return Status::OK();
1375 });
1376
1377 REGISTER_OP("LinSpace")
1378 .Input("start: T")
1379 .Input("stop: T")
1380 .Input("num: Tidx")
1381 .Output("output: T")
1382 .Attr("T: {bfloat16, float, double}")
1383 .Attr("Tidx: {int32, int64} = DT_INT32")
__anonb22bfa860a02(InferenceContext* c) 1384 .SetShapeFn([](InferenceContext* c) {
1385 ShapeHandle unused;
1386 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
1387 " for 'start'");
1388 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
1389 " for 'stop'");
1390 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
1391 " for 'num'");
1392 const Tensor* num_t = c->input_tensor(2);
1393 if (num_t == nullptr) {
1394 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1395 return Status::OK();
1396 }
1397
1398 int64 num;
1399 if (num_t->dtype() == DT_INT32) {
1400 num = num_t->scalar<int32>()();
1401 } else {
1402 num = num_t->scalar<int64>()();
1403 }
1404 if (num <= 0) return errors::InvalidArgument("Requires num > 0: ", num);
1405 c->set_output(0, c->Vector(num));
1406 return Status::OK();
1407 });
1408
// Complex: builds a complex tensor (complex64 by default, complex128 via
// Tout) from float/double `real` and `imag`. The output shape is the
// broadcast of the two input shapes.
REGISTER_OP("Complex")
    .Input("real: T")
    .Input("imag: T")
    .Output("out: Tout")
    .Attr("T: {float, double} = DT_FLOAT")
    .Attr("Tout: {complex64, complex128} = DT_COMPLEX64")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Real / Imag / Angle: complex input to a real (float/double) output of the
// same shape (UnchangedShape).
REGISTER_OP("Real")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Imag")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Angle")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);
1437
1438 REGISTER_OP("Conj")
1439 .Input("input: T")
1440 .Output("output: T")
1441 .Attr("T: {complex64, complex128, variant} = DT_COMPLEX64")
__anonb22bfa860b02(InferenceContext* c) 1442 .SetShapeFn([](InferenceContext* c) {
1443 c->set_output(0, c->input(0));
1444 auto* handle_data = c->input_handle_shapes_and_types(0);
1445 if (handle_data != nullptr) {
1446 c->set_output_handle_shapes_and_types(0, *handle_data);
1447 }
1448 return Status::OK();
1449 });
1450
1451 // --------------------------------------------------------------------------
1452
1453 REGISTER_OP("Cross")
1454 .Input("a: T")
1455 .Input("b: T")
1456 .Output("product: T")
1457 .Attr("T: realnumbertype")
__anonb22bfa860c02(InferenceContext* c) 1458 .SetShapeFn([](InferenceContext* c) {
1459 ShapeHandle a_shape;
1460 ShapeHandle b_shape;
1461 // * Input rank >= 1.
1462 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &a_shape));
1463 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &b_shape));
1464
1465 // * Both inputs have the same shape.
1466 TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &a_shape));
1467
1468 // * input_shape[-1] == 3.
1469 if (c->RankKnown(a_shape)) {
1470 int rank = c->Rank(a_shape);
1471 auto dim = c->Dim(a_shape, rank - 1);
1472 TF_RETURN_IF_ERROR(c->WithValue(dim, 3, &dim));
1473 }
1474 c->set_output(0, a_shape);
1475 return Status::OK();
1476 });
1477
1478 // --------------------------------------------------------------------------
1479
1480 REGISTER_OP("HistogramFixedWidth")
1481 .Input("values: T")
1482 .Input("value_range: T")
1483 .Input("nbins: int32")
1484 .Output("out: dtype")
1485 .Attr("T: {int32, int64, float32, float64}")
1486 .Attr("dtype: {int32, int64} = DT_INT32")
__anonb22bfa860d02(InferenceContext* c) 1487 .SetShapeFn([](InferenceContext* c) {
1488 // value_range should be a vector.
1489 ShapeHandle value_range_shape;
1490 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &value_range_shape));
1491 // value_range should have two elements.
1492 DimensionHandle unused;
1493 TF_RETURN_IF_ERROR(
1494 c->WithValue(c->Dim(value_range_shape, 0), 2, &unused));
1495 // nbins should be a scalar.
1496 ShapeHandle nbins_shape;
1497 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &nbins_shape));
1498
1499 // If nbins is available, set the shape from nbins.
1500 const Tensor* nbins_input = c->input_tensor(2);
1501 if (nbins_input != nullptr) {
1502 int64 nbins;
1503 TF_RETURN_IF_ERROR(c->GetScalarFromTensor(nbins_input, &nbins));
1504 // nbins has to be positive.
1505 if (nbins <= 0) {
1506 return errors::InvalidArgument("Requires nbins > 0: ", nbins);
1507 }
1508 c->set_output(0, c->Vector(nbins));
1509 } else {
1510 c->set_output(0, c->UnknownShapeOfRank(1));
1511 }
1512 return Status::OK();
1513 });
1514
1515 REGISTER_OP("Bincount")
1516 .Input("arr: int32")
1517 .Input("size: int32")
1518 .Input("weights: T")
1519 .Attr("T: {int32, int64, float32, float64}")
1520 .Output("bins: T")
__anonb22bfa860e02(InferenceContext* c) 1521 .SetShapeFn([](InferenceContext* c) {
1522 ShapeHandle unused;
1523 // The input `size` must be a scalar.
1524 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1525
1526 const Tensor* size_tensor = c->input_tensor(1);
1527 if (size_tensor == nullptr) {
1528 // Return unknown shape if size is not known.
1529 c->set_output(0, c->UnknownShapeOfRank(1));
1530 return Status::OK();
1531 }
1532
1533 // Return `[size]` shape if size is known.
1534 int32 size_val = size_tensor->scalar<int32>()();
1535 if (size_val < 0) {
1536 return errors::InvalidArgument("size (", size_val,
1537 ") must be non-negative");
1538 }
1539 c->set_output(0, c->MakeShape({size_val}));
1540 return Status::OK();
1541 });
1542
// Cumulative scans along `axis`. The output shape equals the shape of `x`
// (UnchangedShape); neither the `axis` input nor the exclusive/reverse attrs
// are consulted by shape inference.
REGISTER_OP("Cumsum")
    .Input("x: T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Cumprod")
    .Input("x: T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);
1562
1563 REGISTER_OP("QuantizedMatMul")
1564 .Input("a: T1")
1565 .Input("b: T2")
1566 .Input("min_a: float")
1567 .Input("max_a: float")
1568 .Input("min_b: float")
1569 .Input("max_b: float")
1570 .Output("out: Toutput")
1571 .Output("min_out: float")
1572 .Output("max_out: float")
1573 .Attr("T1: quantizedtype")
1574 .Attr("T2: quantizedtype")
1575 .Attr("Toutput: quantizedtype = DT_QINT32")
1576 .Attr("transpose_a: bool = false")
1577 .Attr("transpose_b: bool = false")
1578 .Attr("Tactivation: quantizedtype = DT_QUINT8")
__anonb22bfa860f02(InferenceContext* c) 1579 .SetShapeFn([](InferenceContext* c) {
1580 TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
1581 ShapeHandle unused;
1582 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1583 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1584 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1585 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
1586
1587 c->set_output(1, c->Scalar());
1588 c->set_output(2, c->Scalar());
1589 return Status::OK();
1590 });
1591
1592 REGISTER_OP("QuantizedMul")
1593 .Input("x: T1")
1594 .Input("y: T2")
1595 .Input("min_x: float")
1596 .Input("max_x: float")
1597 .Input("min_y: float")
1598 .Input("max_y: float")
1599 .Output("z: Toutput")
1600 .Output("min_z: float")
1601 .Output("max_z: float")
1602 .Attr("T1: quantizedtype")
1603 .Attr("T2: quantizedtype")
1604 .Attr("Toutput: quantizedtype = DT_QINT32")
1605 .SetIsCommutative()
__anonb22bfa861002(InferenceContext* c) 1606 .SetShapeFn([](InferenceContext* c) {
1607 TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
1608 c->set_output(1, c->Scalar());
1609 c->set_output(2, c->Scalar());
1610 return Status::OK();
1611 });
1612
1613 REGISTER_OP("QuantizedAdd")
1614 .Input("x: T1")
1615 .Input("y: T2")
1616 .Input("min_x: float")
1617 .Input("max_x: float")
1618 .Input("min_y: float")
1619 .Input("max_y: float")
1620 .Output("z: Toutput")
1621 .Output("min_z: float")
1622 .Output("max_z: float")
1623 .Attr("T1: quantizedtype")
1624 .Attr("T2: quantizedtype")
1625 .Attr("Toutput: quantizedtype = DT_QINT32")
1626 .SetIsCommutative()
__anonb22bfa861102(InferenceContext* c) 1627 .SetShapeFn([](InferenceContext* c) {
1628 TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
1629 // min_x, max_x, min_y, max_y should be scalar.
1630 ShapeHandle unused;
1631 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1632 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1633 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1634 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
1635
1636 c->set_output(1, c->Scalar());
1637 c->set_output(2, c->Scalar());
1638 return Status::OK();
1639 });
1640
1641 REGISTER_OP("QuantizeDownAndShrinkRange")
1642 .Input("input: Tinput")
1643 .Input("input_min: float")
1644 .Input("input_max: float")
1645 .Output("output: out_type")
1646 .Output("output_min: float")
1647 .Output("output_max: float")
1648 .Attr("Tinput: quantizedtype")
1649 .Attr("out_type: quantizedtype")
__anonb22bfa861202(InferenceContext* c) 1650 .SetShapeFn([](InferenceContext* c) {
1651 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1652 ShapeHandle unused;
1653 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1654 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1655 c->set_output(1, c->Scalar());
1656 c->set_output(2, c->Scalar());
1657 return Status::OK();
1658 });
1659
1660 REGISTER_OP("Requantize")
1661 .Input("input: Tinput")
1662 .Input("input_min: float")
1663 .Input("input_max: float")
1664 .Input("requested_output_min: float")
1665 .Input("requested_output_max: float")
1666 .Output("output: out_type")
1667 .Output("output_min: float")
1668 .Output("output_max: float")
1669 .Attr("Tinput: quantizedtype")
1670 .Attr("out_type: quantizedtype")
__anonb22bfa861302(InferenceContext* c) 1671 .SetShapeFn([](InferenceContext* c) {
1672 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1673 ShapeHandle unused;
1674 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1675 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1676 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1677 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1678 c->set_output(1, c->Scalar());
1679 c->set_output(2, c->Scalar());
1680 return Status::OK();
1681 });
1682
1683 REGISTER_OP("CompareAndBitpack")
1684 .Input("input: T")
1685 .Input("threshold: T")
1686 .Output("output: uint8")
1687 .Attr("T: {bool, float16, float32, float64, int8, int16, int32, int64}")
__anonb22bfa861402(InferenceContext* c) 1688 .SetShapeFn([](InferenceContext* c) {
1689 ShapeHandle input;
1690 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
1691 ShapeHandle unused;
1692 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1693 ShapeHandle output = input;
1694 if (c->RankKnown(input)) {
1695 int rank = c->Rank(input);
1696 auto inner_dim = c->Dim(input, rank - 1);
1697 DimensionHandle inferred_dim;
1698 TF_RETURN_IF_ERROR(c->Divide(inner_dim, 8,
1699 /* evenly_divisible */ true,
1700 &inferred_dim));
1701 TF_RETURN_IF_ERROR(
1702 c->ReplaceDim(output, rank - 1, inferred_dim, &output));
1703 }
1704 c->set_output(0, output);
1705
1706 return Status::OK();
1707 });
1708
1709 REGISTER_OP("RequantizationRange")
1710 .Input("input: Tinput")
1711 .Input("input_min: float")
1712 .Input("input_max: float")
1713 .Output("output_min: float")
1714 .Output("output_max: float")
1715 .Attr("Tinput: quantizedtype")
__anonb22bfa861502(InferenceContext* c) 1716 .SetShapeFn([](InferenceContext* c) {
1717 ShapeHandle unused;
1718 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1719 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1720 c->set_output(0, c->Scalar());
1721 c->set_output(1, c->Scalar());
1722 return Status::OK();
1723 });
1724
1725 // --------------------------------------------------------------------------
1726
// Bucketize: maps `input` to an int32 tensor of the same shape, driven by the
// `boundaries` list attr.
REGISTER_OP("Bucketize")
    .Input("input: T")
    .Output("output: int32")
    .Attr("T: {int32, int64, float, double}")
    .Attr("boundaries: list(float)")
    .SetShapeFn(shape_inference::UnchangedShape);

// ClipByValue: output has the shape of `t`. The shapes of clip_value_min and
// clip_value_max are not checked by this shape function.
REGISTER_OP("ClipByValue")
    .Input("t: T")
    .Input("clip_value_min: T")
    .Input("clip_value_max: T")
    .Output("output: T")
    .Attr("T: numbertype")
    .SetShapeFn(shape_inference::UnchangedShape);
1741
#ifdef INTEL_MKL
// _MklAddN: MKL variant of AddN. Shape inference merges every input shape,
// from the last input down to the first, and emits the merged shape as `sum`.
// The `mkl_sum` output's shape is not set here.
REGISTER_OP("_MklAddN")
    .Input("inputs: N * T")
    .Input("mkl_input: N * uint8")
    .Output("sum: T")
    .Output("mkl_sum: uint8")
    .Attr("N: int >= 1")
    .Attr("T: numbertype")
    .SetIsCommutative()
    .SetIsAggregate()
    .SetShapeFn([](InferenceContext* c) {
      // NOTE(review): num_inputs() presumably also counts the N `mkl_input`
      // tensors, so those get merged with the data shapes too. Confirm that
      // this is intended.
      const int last = c->num_inputs() - 1;
      ShapeHandle merged = c->input(last);
      int i = last;
      while (--i >= 0) {
        TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), merged, &merged),
                                        "From merging shape ", i,
                                        " with other shapes.");
      }
      c->set_output(0, merged);
      return Status::OK();
    })
    .Doc(R"doc(
Add two input tensors element wise using mkl kernel sum.
inputs: Must all be the same size and shape.
)doc");

#endif  // INTEL_MKL
1768
1769 REGISTER_OP("RequantizePerChannel")
1770 .Input("input: T")
1771 .Input("input_min: float")
1772 .Input("input_max: float")
1773 .Input("requested_output_min: float")
1774 .Input("requested_output_max: float")
1775 .Output("output: out_type")
1776 .Output("output_min: float")
1777 .Output("output_max: float")
1778 .Attr("T: quantizedtype = DT_QINT32")
1779 .Attr("out_type: quantizedtype = DT_QUINT8")
__anonb22bfa861702(InferenceContext* c) 1780 .SetShapeFn([](InferenceContext* c) {
1781 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1782 ShapeHandle unused;
1783 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
1784 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
1785 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1786 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1787 c->set_output(1, c->Scalar());
1788 c->set_output(2, c->Scalar());
1789 return Status::OK();
1790 });
1791 REGISTER_OP("RequantizationRangePerChannel")
1792 .Input("input: T")
1793 .Input("input_min: float")
1794 .Input("input_max: float")
1795 .Output("output_min: float")
1796 .Output("output_max: float")
1797 .Attr("T: quantizedtype = DT_QINT32")
1798 .Attr("clip_value_max: float")
__anonb22bfa861802(InferenceContext* c) 1799 .SetShapeFn([](InferenceContext* c) {
1800 ShapeHandle unused;
1801 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
1802 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
1803 c->set_output(0, c->Scalar());
1804 c->set_output(1, c->Scalar());
1805 return Status::OK();
1806 });
1807
// NextAfter: elementwise binary op over float32/float64 (default DT_FLOAT).
// The output shape is the broadcast of x1 and x2.
REGISTER_OP("NextAfter")
    .Attr("T: {float64, float32} = DT_FLOAT")
    .Input("x1: T")
    .Input("x2: T")
    .Output("output: T")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
1814
1815 } // namespace tensorflow
1816