1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/numeric_op.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20 
21 namespace tensorflow {
22 
23 using shape_inference::DimensionHandle;
24 using shape_inference::InferenceContext;
25 using shape_inference::ShapeHandle;
26 
27 REGISTER_OP("AddN")
28     .Input("inputs: N * T")
29     .Output("sum: T")
30     .Attr("N: int >= 1")
31     .Attr("T: {numbertype, variant}")
32     .SetIsCommutative()
33     .SetIsAggregate()
34     .SetShapeFn([](InferenceContext* c) {
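      // The output shape is the pairwise Merge of all input shapes: the most
      // specific shape compatible with every input. Merge fails (and shape
      // inference errors out) if any two inputs have incompatible shapes.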
35       ShapeHandle cur = c->input(c->num_inputs() - 1);
36       for (int i = c->num_inputs() - 2; i >= 0; --i) {
37         TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
38                                         "From merging shape ", i,
39                                         " with other shapes.");
40       }
41       c->set_output(0, cur);
42 
43       DataType dtype;
44       TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
45 
46       if (dtype != DT_VARIANT) {
47         // Exit early if not DT_VARIANT.
48         return Status::OK();
49       } else {
50         // DT_VARIANT shape handle shape inference.  All sizes and dtypes must
51         // be the same; all shapes must be compatible via Merge.
52         std::vector<shape_inference::ShapeAndType> cur_shapes_and_types;
53         auto* shapes_and_types =
54             c->input_handle_shapes_and_types(c->num_inputs() - 1);
55         if (shapes_and_types) {
56           cur_shapes_and_types = *shapes_and_types;
57         }
58 
59         for (int i = c->num_inputs() - 2; i >= 0; --i) {
60           auto shapes_and_types_i = c->input_handle_shapes_and_types(i);
61           if (!shapes_and_types && shapes_and_types_i) {
62             // TODO(ebrevdo): Find cases where this happens and fix their shape
63             // inference.  If we are calling AddN on variant types, they should
64             // all have consistent shape_and_type info.
65             shapes_and_types = shapes_and_types_i;
66           } else if (shapes_and_types && shapes_and_types_i) {
67             if (shapes_and_types_i->size() != shapes_and_types->size()) {
68               return errors::InvalidArgument(
69                   "shapes_and_types[", i,
70                   "].size() == ", shapes_and_types_i->size(),
71                   " != shapes_and_types[0].size() == ",
72                   shapes_and_types->size());
73             }
74             for (int j = 0; j < shapes_and_types->size(); ++j) {
75               if (shapes_and_types->at(j).dtype !=
76                   shapes_and_types_i->at(j).dtype) {
77                 return errors::InvalidArgument(
78                     "shapes_and_types[", i, "][", j, "].dtype() == ",
79                     DataTypeString(shapes_and_types_i->at(j).dtype),
80                     " != shapes_and_types[0][", j, "].dtype == ",
81                     DataTypeString(shapes_and_types->at(j).dtype));
82               }
83               TF_RETURN_WITH_CONTEXT_IF_ERROR(
84                   c->Merge(shapes_and_types_i->at(j).shape,
85                            cur_shapes_and_types.at(j).shape,
86                            &cur_shapes_and_types.at(j).shape),
87                   "From merging shapes_and_types[", i, "][", j, "].shape with ",
88                   "shapes_and_types[0][", j, "].shape");
89             }
90           }
91         }
92         if (shapes_and_types) {
93           c->set_output_handle_shapes_and_types(0, cur_shapes_and_types);
94         }
95         return Status::OK();
96       }
97     });
98 
99 // --------------------------------------------------------------------------
100 
101 // Note that the following operator is just a placeholder and has no
102 // associated kernel. The code in accumulate_n_optimizer.cc replaces
103 // this placeholder with a graph of operators that do have kernels.
104 // The Python code that generates instances of this op is currently in
105 // contrib/framework/python/ops/accumulate_n_v2.py
106 REGISTER_OP("AccumulateNV2")
107     .Input("inputs: N * T")
108     .Output("sum: T")
109     .Attr("N: int >= 1")
110     .Attr("T: numbertype")
111     .Attr("shape: shape")
112     .SetIsCommutative()
113     .SetIsAggregate()
114     .SetShapeFn(shape_inference::ExplicitShape);
115 
116 // --------------------------------------------------------------------------
117 
118 REGISTER_OP("BatchMatMul")
119     .Input("x: T")
120     .Input("y: T")
121     .Output("output: T")
122     .Attr(
123         "T: {bfloat16, half, float, double, int32, int64, complex64, "
124         "complex128}")
125     .Attr("adj_x: bool = false")
126     .Attr("adj_y: bool = false")
127     .SetShapeFn([](InferenceContext* c) {
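      // Output shape is batch_dims + [output_rows, output_cols]: the row/col
      // dims come from the (possibly adjointed) trailing two dims of x and y,
      // and the leading batch dims of the two inputs must merge.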
128       ShapeHandle a_shape;
129       ShapeHandle b_shape;
130       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &a_shape));
131       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 2, &b_shape));
132 
133       // Determine output rows and cols.
134       bool adj_x;
135       bool adj_y;
136       TF_RETURN_IF_ERROR(c->GetAttr("adj_x", &adj_x));
137       TF_RETURN_IF_ERROR(c->GetAttr("adj_y", &adj_y));
138       DimensionHandle output_rows = c->Dim(a_shape, adj_x ? -1 : -2);
139       DimensionHandle output_cols = c->Dim(b_shape, adj_y ? -2 : -1);
140 
141       // Batch dims match between inputs.
142       ShapeHandle a_batch_dims;
143       ShapeHandle b_batch_dims;
144       ShapeHandle batch_dims;
145       TF_RETURN_IF_ERROR(c->Subshape(a_shape, 0, -2, &a_batch_dims));
146       TF_RETURN_IF_ERROR(c->Subshape(b_shape, 0, -2, &b_batch_dims));
147       TF_RETURN_IF_ERROR(c->Merge(a_batch_dims, b_batch_dims, &batch_dims));
148 
149       // Assert inner dims match.
150       DimensionHandle unused;
151       TF_RETURN_IF_ERROR(c->Merge(c->Dim(a_shape, adj_x ? -2 : -1),
152                                   c->Dim(b_shape, adj_y ? -1 : -2), &unused));
153 
154       ShapeHandle out;
155       TF_RETURN_IF_ERROR(c->Concatenate(
156           batch_dims, c->Matrix(output_rows, output_cols), &out));
157       c->set_output(0, out);
158       return Status::OK();
159     });
160 
161 // --------------------------------------------------------------------------
162 // Casting Ops
163 //
164 // NOTE: Only a limited set of types is supported by
165 // Cast. The exact casting rule is TBD. The current
166 // implementation uses C++ static_cast rules for numeric
167 // types, which may change in the future.
168 REGISTER_OP("Cast")
169     .Input("x: SrcT")
170     .Output("y: DstT")
171     .Attr("SrcT: type")
172     .Attr("DstT: type")
173     .Attr("Truncate: bool = false")
174     .SetShapeFn(shape_inference::UnchangedShape);
175 
176 REGISTER_OP("_HostCast")
177     .Input("x: SrcT")
178     .Output("y: DstT")
179     .Attr("SrcT: type")
180     .Attr("DstT: type")
181     .Attr("Truncate: bool = false")
182     .SetShapeFn(shape_inference::UnchangedShape)
183     .Doc(R"doc(
184 Cast x of type SrcT to y of DstT.
185 
186 _HostCast requires its input and produces its output in host memory.
187 )doc");
188 
189 // --------------------------------------------------------------------------
190 
191 REGISTER_OP("Abs")
192     .Input("x: T")
193     .Output("y: T")
194     .Attr("T: {bfloat16, half, float, double, int32, int64}")
195     .SetShapeFn(shape_inference::UnchangedShape);
196 
197 REGISTER_OP("ComplexAbs")
198     .Input("x: T")
199     .Output("y: Tout")
200     .Attr("T: {complex64, complex128} = DT_COMPLEX64")
201     .Attr("Tout: {float, double} = DT_FLOAT")
202     .SetShapeFn(shape_inference::UnchangedShape);
203 
204 // Declares cwise unary operations signature: 't -> 't
205 #define UNARY()                                                          \
206   Input("x: T")                                                          \
207       .Output("y: T")                                                    \
208       .Attr(                                                             \
209           "T: {bfloat16, half, float, double, int32, int64, complex64, " \
210           "complex128}")                                                 \
211       .SetShapeFn(shape_inference::UnchangedShape)
212 
213 #define UNARY_REAL()                              \
214   Input("x: T")                                   \
215       .Output("y: T")                             \
216       .Attr("T: {bfloat16, half, float, double}") \
217       .SetShapeFn(shape_inference::UnchangedShape)
218 
219 #define UNARY_COMPLEX()                                                  \
220   Input("x: T")                                                          \
221       .Output("y: T")                                                    \
222       .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
223       .SetShapeFn(shape_inference::UnchangedShape)
224 
225 #define UNARY_GRADIENT_COMPLEX()                                         \
226   Input("y: T")                                                          \
227       .Input("dy: T")                                                    \
228       .Output("z: T")                                                    \
229       .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
230       .SetShapeFn(shape_inference::UnchangedShape)
231 
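// For example, REGISTER_OP("Neg").UNARY() below expands to an op with input
// "x: T", output "y: T", the numeric/complex type list from UNARY(), and
// UnchangedShape shape inference; the other UNARY_* variants differ only in
// the allowed types and, for the gradient macro, the (y, dy) -> z signature.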
232 REGISTER_OP("Neg").UNARY();
233 
234 REGISTER_OP("Inv").UNARY();
235 
236 REGISTER_OP("InvGrad").UNARY_GRADIENT_COMPLEX();
237 
238 REGISTER_OP("Reciprocal").UNARY();
239 
240 REGISTER_OP("ReciprocalGrad").UNARY_GRADIENT_COMPLEX();
241 
242 REGISTER_OP("Square").UNARY();
243 
244 REGISTER_OP("Sqrt").UNARY_COMPLEX();
245 
246 REGISTER_OP("SqrtGrad").UNARY_GRADIENT_COMPLEX();
247 
248 REGISTER_OP("Rsqrt").UNARY_COMPLEX();
249 
250 REGISTER_OP("Round").UNARY();
251 
252 REGISTER_OP("RsqrtGrad").UNARY_GRADIENT_COMPLEX();
253 
254 REGISTER_OP("Exp").UNARY_COMPLEX();
255 
256 REGISTER_OP("Expm1").UNARY_COMPLEX();
257 
258 REGISTER_OP("Log").UNARY_COMPLEX();
259 
260 REGISTER_OP("Log1p").UNARY_COMPLEX();
261 
262 REGISTER_OP("Sinh").UNARY_COMPLEX();
263 
264 REGISTER_OP("Cosh").UNARY_COMPLEX();
265 
266 REGISTER_OP("Tanh").UNARY_COMPLEX();
267 
268 REGISTER_OP("Asinh").UNARY_COMPLEX();
269 
270 REGISTER_OP("Acosh").UNARY_COMPLEX();
271 
272 REGISTER_OP("Atanh").UNARY_COMPLEX();
273 
274 REGISTER_OP("TanhGrad").UNARY_GRADIENT_COMPLEX();
275 
276 REGISTER_OP("Lgamma").UNARY_REAL();
277 
278 REGISTER_OP("Digamma").UNARY_REAL();
279 
280 REGISTER_OP("Erf").UNARY_REAL();
281 
282 REGISTER_OP("Erfc").UNARY_REAL();
283 
284 REGISTER_OP("Sigmoid").UNARY_COMPLEX();
285 
286 REGISTER_OP("SigmoidGrad").UNARY_GRADIENT_COMPLEX();
287 
288 REGISTER_OP("Sin").UNARY_COMPLEX();
289 
290 REGISTER_OP("Cos").UNARY_COMPLEX();
291 
292 REGISTER_OP("Tan").UNARY();
293 
294 REGISTER_OP("Asin").UNARY();
295 
296 REGISTER_OP("Acos").UNARY();
297 
298 REGISTER_OP("Atan").UNARY();
299 
300 REGISTER_OP("BesselI0e").UNARY_REAL();
301 
302 REGISTER_OP("BesselI1e").UNARY_REAL();
303 
304 REGISTER_OP("_UnaryOpsComposition")
305     .Input("x: T")
306     .Output("y: T")
307     .Attr("T: {float, half, double}")
308     .Attr("op_names: list(string)")
309     .SetShapeFn(shape_inference::UnchangedShape)
310     .Doc(R"doc(
311 *NOTE*: Do not invoke this operator directly in Python. A graph rewrite pass is
312 expected to create these operators.
313 )doc");
314 
315 #undef UNARY
316 #undef UNARY_REAL
317 #undef UNARY_COMPLEX
318 
319 REGISTER_OP("IsNan")
320     .Input("x: T")
321     .Output("y: bool")
322     .Attr("T: {bfloat16, half, float, double}")
323     .SetShapeFn(shape_inference::UnchangedShape);
324 
325 REGISTER_OP("IsInf")
326     .Input("x: T")
327     .Output("y: bool")
328     .Attr("T: {bfloat16, half, float, double}")
329     .SetShapeFn(shape_inference::UnchangedShape);
330 
331 REGISTER_OP("IsFinite")
332     .Input("x: T")
333     .Output("y: bool")
334     .Attr("T: {bfloat16, half, float, double}")
335     .SetShapeFn(shape_inference::UnchangedShape);
336 
337 REGISTER_OP("Sign")
338     .Input("x: T")
339     .Output("y: T")
340     .Attr(
341         "T: {bfloat16, half, float, double, int32, int64, complex64, "
342         "complex128}")
343     .SetShapeFn(shape_inference::UnchangedShape);
344 
345 REGISTER_OP("Floor")
346     .Input("x: T")
347     .Output("y: T")
348     .Attr("T: {bfloat16, half, float, double}")
349     .SetShapeFn(shape_inference::UnchangedShape);
350 
351 REGISTER_OP("Ceil")
352     .Input("x: T")
353     .Output("y: T")
354     .Attr("T: {bfloat16, half, float, double}")
355     .SetShapeFn(shape_inference::UnchangedShape);
356 
357 REGISTER_OP("Rint")
358     .Input("x: T")
359     .Output("y: T")
360     .Attr("T: {bfloat16, half, float, double}")
361     .SetShapeFn(shape_inference::UnchangedShape);
362 
363 // Declares cwise binary operations signature: 't, 't -> 't.
364 
365 #define BINARY_MORE()                                                          \
366   Input("x: T").Input("y: T").Output("z: T").Attr(                             \
367       "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, int32, " \
368       "int64, complex64, complex128}")
369 
370 #define BINARY_FEWER()                                               \
371   Input("x: T").Input("y: T").Output("z: T").Attr(                   \
372       "T: {bfloat16, half, float, double, int32, int64, complex64, " \
373       "complex128}")
374 
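// Unlike the UNARY* macros, these binary macros do not set a shape function;
// each op that uses them chains .SetShapeFn(BroadcastBinaryOpShapeFn) (and,
// where appropriate, .SetIsCommutative()) explicitly.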
375 REGISTER_OP("Add")
376     .Input("x: T")
377     .Input("y: T")
378     .Output("z: T")
379     .Attr(
380         "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
381         "complex64, complex128, string}")
382     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
383 
384 // TODO(rmlarsen): Add a Python wrapper that switches non-string instances to
385 // use AddV2 (b/68646025).
386 REGISTER_OP("AddV2")
387     .Input("x: T")
388     .Input("y: T")
389     .Output("z: T")
390     .Attr(
391         "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
392         "complex64, complex128}")
393     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
394     .SetIsAggregate()
395     .SetIsCommutative();
396 
397 REGISTER_OP("_MklAdd")
398     .Input("x: T")
399     .Input("y: T")
400     .Input("mkl_x: uint8")
401     .Input("mkl_y: uint8")
402     .Output("z: T")
403     .Output("mkl_z: uint8")
404     .Attr(
405         "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
406         "complex128, string}")
407     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
408     .Doc(R"doc(
409 Returns `x` + `y` element-wise.
410 
411 *NOTE*: `tf.math.add` supports broadcasting. `tf.math.add_n` does not. More about broadcasting
412 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
413 )doc");
414 
415 REGISTER_OP("Sub").BINARY_MORE().SetShapeFn(
416     shape_inference::BroadcastBinaryOpShapeFn);
417 
418 REGISTER_OP("_MklSub")
419     .BINARY_FEWER()
420     .Input("mkl_x: uint8")
421     .Input("mkl_y: uint8")
422     .Output("mkl_z: uint8")
423     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
424     .Doc(R"doc(
425 Returns x - y element-wise.
426 
427 *NOTE*: `Sub` supports broadcasting. More about broadcasting
428 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
429 )doc");
430 
431 REGISTER_OP("Mul").BINARY_MORE().SetIsCommutative().SetShapeFn(
432     shape_inference::BroadcastBinaryOpShapeFn);
433 
434 REGISTER_OP("MulNoNan")
435     .Input("x: T")
436     .Input("y: T")
437     .Output("z: T")
438     .Attr("T: {half, float, double, complex64, complex128}")
439     .SetIsCommutative()
440     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
441 
442 REGISTER_OP("_MklMul")
443     .BINARY_MORE()
444     .Input("mkl_x: uint8")
445     .Input("mkl_y: uint8")
446     .Output("mkl_z: uint8")
447     .SetIsCommutative()
448     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
449     .Doc(R"doc(
450 Returns x * y element-wise.
451 
452 *NOTE*: `Mul` supports broadcasting. More about broadcasting
453 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
454 )doc");
455 
456 REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
457     shape_inference::BroadcastBinaryOpShapeFn);
458 
459 REGISTER_OP("DivNoNan")
460     .Input("x: T")
461     .Input("y: T")
462     .Output("z: T")
463     .Attr("T: {half, float, double, complex64, complex128}")
464     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
465 
466 REGISTER_OP("FloorDiv")
467     .BINARY_MORE()
468     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
469 
470 REGISTER_OP("TruncateDiv")
471     .BINARY_MORE()
472     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
473 
474 REGISTER_OP("RealDiv").BINARY_MORE().SetShapeFn(
475     shape_inference::BroadcastBinaryOpShapeFn);
476 
477 REGISTER_OP("SquaredDifference")
478     .BINARY_FEWER()
479     .SetIsCommutative()
480     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
481 
482 REGISTER_OP("_MklSquaredDifference")
483     .BINARY_FEWER()
484     .Input("mkl_x: uint8")
485     .Input("mkl_y: uint8")
486     .Output("mkl_z: uint8")
487     .SetIsCommutative()
488     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
489     .Doc(R"doc(
490 Returns (x - y)(x - y) element-wise.
491 
492 *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
493 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
494 )doc");
495 
496 REGISTER_OP("Xlogy")
497     .Input("x: T")
498     .Input("y: T")
499     .Output("z: T")
500     .Attr("T: {half, float, double, complex64, complex128}")
501     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
502 
503 REGISTER_OP("Xdivy")
504     .Input("x: T")
505     .Input("y: T")
506     .Output("z: T")
507     .Attr("T: {half, float, double, complex64, complex128}")
508     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
509 
510 #undef BINARY_FEWER
511 #undef BINARY_MORE
512 
513 REGISTER_OP("Maximum")
514     .Input("x: T")
515     .Input("y: T")
516     .Output("z: T")
517     .Attr("T: {bfloat16, half, float, double, int32, int64}")
518     .SetIsCommutative()
519     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
520 
521 REGISTER_OP("_MklMaximum")
522     .Input("x: T")
523     .Input("y: T")
524     .Input("mkl_x: uint8")
525     .Input("mkl_y: uint8")
526     .Output("z: T")
527     .Output("mkl_z: uint8")
528     .Attr("T: {half, float, double, int32, int64}")
529     .SetIsCommutative()
530     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
531     .Doc(R"doc(
532 Returns the max of x and y (i.e. x > y ? x : y) element-wise.
533 
534 *NOTE*: `Maximum` supports broadcasting. More about broadcasting
535 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
536 )doc");
537 
538 REGISTER_OP("Minimum")
539     .Input("x: T")
540     .Input("y: T")
541     .Output("z: T")
542     .Attr("T: {bfloat16, half, float, double, int32, int64}")
543     .SetIsCommutative()
544     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
545 
546 REGISTER_OP("Mod")
547     .Input("x: T")
548     .Input("y: T")
549     .Output("z: T")
550     .Attr("T: {int32, int64, float16, half, bfloat16, float, double}")
551     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
552 
553 REGISTER_OP("FloorMod")
554     .Input("x: T")
555     .Input("y: T")
556     .Output("z: T")
557     .Attr("T: {int32, int64, bfloat16, half, float, double}")
558     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
559 
560 REGISTER_OP("TruncateMod")
561     .Input("x: T")
562     .Input("y: T")
563     .Output("z: T")
564     .Attr("T: {int32, int64, bfloat16, half, float, double}")
565     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
566 
567 REGISTER_OP("Pow")
568     .Input("x: T")
569     .Input("y: T")
570     .Output("z: T")
571     .Attr(
572         "T: {bfloat16, float, half, double, int32, int64, complex64, "
573         "complex128}")
574     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
575 
576 REGISTER_OP("Igammac")
577     .Input("a: T")
578     .Input("x: T")
579     .Output("z: T")
580     .Attr("T: {float, double}")
581     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
582 
583 REGISTER_OP("Igamma")
584     .Input("a: T")
585     .Input("x: T")
586     .Output("z: T")
587     .Attr("T: {float, double}")
588     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
589 
590 REGISTER_OP("IgammaGradA")
591     .Input("a: T")
592     .Input("x: T")
593     .Output("z: T")
594     .Attr("T: {float, double}")
595     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
596 
597 REGISTER_OP("Zeta")
598     .Input("x: T")
599     .Input("q: T")
600     .Output("z: T")
601     .Attr("T: {float, double}")
602     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
603 
604 REGISTER_OP("Polygamma")
605     .Input("a: T")
606     .Input("x: T")
607     .Output("z: T")
608     .Attr("T: {float, double}")
609     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
610 
611 REGISTER_OP("Atan2")
612     .Input("y: T")
613     .Input("x: T")
614     .Output("z: T")
615     .Attr("T: {bfloat16, half, float, double}")
616     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
617 
618 REGISTER_OP("Betainc")
619     .Input("a: T")
620     .Input("b: T")
621     .Input("x: T")
622     .Output("z: T")
623     .Attr("T: {float, double}")
624     .SetShapeFn([](InferenceContext* c) {
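      // Broadcasting here is limited to scalars: each input is either a
      // scalar or must have the common (merged) shape of the non-scalar
      // inputs; there is no general elementwise broadcasting.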
625       const int num_inputs = 3;
626       ShapeHandle output = c->UnknownShape();
627       int num_scalars = 0;
628       ShapeHandle some_non_scalar;
629       for (int i = 0; i < num_inputs; ++i) {
630         ShapeHandle in = c->input(i);
631         if (!c->RankKnown(in)) {
632           some_non_scalar = in;
633           // An input with unknown rank could be either a scalar (to be
634           // broadcast) or some other shape.
635         } else if (c->Rank(in) == 0) {
636           // Input is a scalar, it will be broadcast to the output shape.
637           ++num_scalars;
638         } else {
639           TF_RETURN_IF_ERROR(c->Merge(output, in, &output));
640           some_non_scalar = output;
641         }
642       }
643 
644       if (num_scalars == num_inputs - 1) {
645         // If all but one input is known to be a scalar, then output is the
646         // remaining input.
647         output = some_non_scalar;
648       } else if (num_scalars == num_inputs) {
649         // If all are scalars, output is scalar; pick the first one arbitrarily.
650         output = c->input(0);
651       }
652 
653       c->set_output(0, output);
654       return Status::OK();
655     });
656 
657 // --------------------------------------------------------------------------
658 
659 // Declares cwise binary comparison operations signature: 't, 't -> bool,
660 // where 't has a natural total order.
661 #define COMPARISON()             \
662   Input("x: T")                  \
663       .Input("y: T")             \
664       .Output("z: bool")         \
665       .Attr("T: realnumbertype") \
666       .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
667 
668 REGISTER_OP("Less").COMPARISON();
669 
670 REGISTER_OP("LessEqual").COMPARISON();
671 
672 REGISTER_OP("Greater").COMPARISON();
673 
674 REGISTER_OP("GreaterEqual").COMPARISON();
675 
676 #undef COMPARISON
677 
678 // --------------------------------------------------------------------------
679 
680 #define EQUALITY_COMPARISON()                                              \
681   Input("x: T")                                                            \
682       .Input("y: T")                                                       \
683       .Output("z: bool")                                                   \
684       .SetIsCommutative()                                                  \
685       .Attr(                                                               \
686           "T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
687           "int64, complex64, quint8, qint8, qint32, string, bool, "        \
688           "complex128}")                                                   \
689       .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
690 
691 REGISTER_OP("Equal").EQUALITY_COMPARISON();
692 
693 REGISTER_OP("NotEqual").EQUALITY_COMPARISON();
694 
695 #undef EQUALITY_COMPARISON
696 
697 REGISTER_OP("ApproximateEqual")
698     .Input("x: T")
699     .Input("y: T")
700     .Output("z: bool")
701     .SetIsCommutative()
702     .Attr("T: numbertype")
703     .Attr("tolerance: float = 0.00001")
704     .SetShapeFn([](InferenceContext* c) {
705       // The inputs 'x' and 'y' must have the same shape.
706       ShapeHandle data_x = c->input(0);
707       ShapeHandle data_y = c->input(1);
708       TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x));
709       return shape_inference::UnchangedShape(c);
710     });
711 
712 // --------------------------------------------------------------------------
713 
714 REGISTER_OP("LogicalNot")
715     .Input("x: bool")
716     .Output("y: bool")
717     .SetShapeFn(shape_inference::UnchangedShape);
718 
719 #define BINARY_LOGICAL()  \
720   Input("x: bool")        \
721       .Input("y: bool")   \
722       .Output("z: bool")  \
723       .SetIsCommutative() \
724       .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
725 
726 REGISTER_OP("LogicalAnd").BINARY_LOGICAL();
727 
728 REGISTER_OP("LogicalOr").BINARY_LOGICAL();
729 
730 #undef BINARY_LOGICAL
731 
732 // --------------------------------------------------------------------------
733 
734 REGISTER_OP("Select")
735     .Input("condition: bool")
736     .Input("t: T")
737     .Input("e: T")
738     .Output("output: T")
739     .Attr("T: type")
740     .SetShapeFn([](InferenceContext* c) {
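      // Cases handled below: forward handle shape/dtype data when both 't'
      // and 'e' carry it; 't' and 'e' must have matching shapes; 'condition'
      // is either a scalar, a vector whose length matches the first dimension
      // of 't'/'e', or has the same shape as 't'/'e'.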
741       auto* handle_data_1 = c->input_handle_shapes_and_types(1);
742       auto* handle_data_2 = c->input_handle_shapes_and_types(2);
743       // Merge handle shape and dtype if applicable.
744       if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
745         const auto size = handle_data_1->size();
746         std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
747         if (size != handle_data_2->size()) {
748           return errors::InvalidArgument(
749               "Trying to merge handles pointing to different numbers of "
750               "tensors.");
751         }
752 
753         for (int i = 0; i < size; ++i) {
754           const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
755           const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
756           if (s1.dtype != s2.dtype) {
757             // TODO(apassos) resolve this in the manner of b/32476923
758             return errors::InvalidArgument(
759                 "Trying to merge handles pointing to different dtypes.");
760           }
761           merged_handle_data[i].dtype = s1.dtype;
762           TF_RETURN_IF_ERROR(
763               c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
764         }
765 
766         c->set_output_handle_shapes_and_types(0, merged_handle_data);
767       }
768 
769       // The inputs 'then' and 'else' must have the same shape.
770       ShapeHandle data = c->input(1);
771       ShapeHandle other = c->input(2);
772       TF_RETURN_IF_ERROR(c->Merge(data, other, &data));
773 
774       // The input 'cond' must either have the same shape as 'then' and
775       // 'else', or be a vector if 'then' and 'else' are at least vectors.
776       ShapeHandle cond = c->input(0);
777 
778       if (!c->RankKnown(cond) || !c->RankKnown(data)) {
779         c->set_output(0, data);
780         return Status::OK();
781       }
782 
783       // rank of shape and data is known.
784 
785       const int32 cond_rank = c->Rank(cond);
786       const int32 data_rank = c->Rank(data);
787 
788       if (cond_rank == 0) {
789         // 'cond' is a scalar;
790         // 't' and 'e' can have any shape.
791         c->set_output(0, data);
792         return Status::OK();
793       }
794 
795       if (cond_rank != 1) {
796         // If 'cond' is neither a vector nor a scalar, its shape must
797         // match the shape of 'then' and 'else'.
798         TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
799         c->set_output(0, data);
800         return Status::OK();
801       }
802 
803       if (data_rank == 0) {
804         // If 'then' and 'else' are scalars, 'cond' must be a scalar too.
805         TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
806         c->set_output(0, data);
807         return Status::OK();
808       }
809 
810       if (cond_rank == 1) {
811         // If 'cond' is a vector and 'then' is not a scalar, the first
812         // dimension of 'then' and 'else' must match the length of 'cond'.
813         TF_RETURN_IF_ERROR(c->Merge(cond, c->Vector(c->Dim(data, 0)), &cond));
814         c->set_output(0, data);
815         return Status::OK();
816       }
817 
818       c->set_output(0, data);
819 
820       return Status::OK();
821     });
822 
823 // --------------------------------------------------------------------------
824 
825 REGISTER_OP("MatMul")
826     .Input("a: T")
827     .Input("b: T")
828     .Output("product: T")
829     .Attr("transpose_a: bool = false")
830     .Attr("transpose_b: bool = false")
831     .Attr(
832         "T: {bfloat16, half, float, double, int32, int64, complex64, "
833         "complex128}")
834     .SetShapeFn(shape_inference::MatMulShape);
835 
836 REGISTER_OP("SparseMatMul")
837     .Input("a: Ta")
838     .Input("b: Tb")
839     .Output("product: float")
840     .Attr("transpose_a: bool = false")
841     .Attr("transpose_b: bool = false")
842     .Attr("a_is_sparse: bool = false")
843     .Attr("b_is_sparse: bool = false")
844     .Attr("Ta: {float, bfloat16} = DT_FLOAT")
845     .Attr("Tb: {float, bfloat16} = DT_FLOAT")
846     .SetShapeFn(shape_inference::MatMulShape);
847 
848 // --------------------------------------------------------------------------
849 
850 // For operations where the output is a reduction function along some
851 // dimensions of the input.
852 REGISTER_OP("Sum")
853     .Input("input: T")
854     .Input("reduction_indices: Tidx")
855     .Output("output: T")
856     .Attr("keep_dims: bool = false")
857     .Attr("T: numbertype")
858     .Attr("Tidx: {int32, int64} = DT_INT32")
859     .SetShapeFn(shape_inference::ReductionShape);
860 
861 REGISTER_OP("EuclideanNorm")
862     .Input("input: T")
863     .Input("reduction_indices: Tidx")
864     .Output("output: T")
865     .Attr("keep_dims: bool = false")
866     .Attr("T: numbertype")
867     .Attr("Tidx: {int32, int64} = DT_INT32")
868     .SetShapeFn(shape_inference::ReductionShape);
869 
870 REGISTER_OP("Mean")
871     .Input("input: T")
872     .Input("reduction_indices: Tidx")
873     .Output("output: T")
874     .Attr("keep_dims: bool = false")
875     .Attr("T: numbertype")
876     .Attr("Tidx: {int32, int64} = DT_INT32")
877     .SetShapeFn(shape_inference::ReductionShape);
878 
879 REGISTER_OP("Prod")
880     .Input("input: T")
881     .Input("reduction_indices: Tidx")
882     .Output("output: T")
883     .Attr("keep_dims: bool = false")
884     .Attr("T: numbertype")
885     .Attr("Tidx: {int32, int64} = DT_INT32")
886     .SetShapeFn(shape_inference::ReductionShape);
887 
888 REGISTER_OP("Min")
889     .Input("input: T")
890     .Input("reduction_indices: Tidx")
891     .Output("output: T")
892     .Attr("keep_dims: bool = false")
893     .Attr("T: numbertype")
894     .Attr("Tidx: {int32, int64} = DT_INT32")
895     .SetShapeFn(shape_inference::ReductionShape);
896 
897 REGISTER_OP("Max")
898     .Input("input: T")
899     .Input("reduction_indices: Tidx")
900     .Output("output: T")
901     .Attr("keep_dims: bool = false")
902     .Attr("T: numbertype")
903     .Attr("Tidx: {int32, int64} = DT_INT32")
904     .SetShapeFn(shape_inference::ReductionShape);
905 
906 namespace {
907 
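// Shape function shared by ArgMax and ArgMin below: the output is the input
// shape with the reduced dimension removed (a scalar for rank <= 1 inputs).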
908 Status ArgOpShape(shape_inference::InferenceContext* c) {
909   ShapeHandle dimension_shape;
910   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &dimension_shape));
911 
912   ShapeHandle input_shape = c->input(0);
913   if (!c->RankKnown(input_shape)) {
914     return shape_inference::UnknownShape(c);
915   }
916 
917   const int32 input_rank = c->Rank(input_shape);
918   if (input_rank <= 1) {
919     // Reducing a scalar/vector must return a scalar.
920     return shape_inference::ScalarShape(c);
921   }
922 
923   const Tensor* dim_t = c->input_tensor(1);
924   if (dim_t == nullptr) {
925     // We don't know the value of the dimension, but we
926     // know the rank of the input, so return the correct
927     // rank with unknown dimensions.
928     std::vector<DimensionHandle> dims(input_rank - 1);
929     for (int i = 0; i < dims.size(); ++i) {
930       dims[i] = c->UnknownDim();
931     }
932 
933     c->set_output(0, c->MakeShape(dims));
934     return Status::OK();
935   }
936 
937   int64 dimension_val;
938   if (dim_t->dtype() == DT_INT32) {
939     dimension_val = dim_t->scalar<int32>()();
940   } else {
941     dimension_val = dim_t->scalar<int64>()();
942   }
943 
944   int64 axis = dimension_val < 0 ? dimension_val + input_rank : dimension_val;
945   if (axis < 0 || axis >= input_rank) {
946     return errors::InvalidArgument(
947         "Dimension (", dimension_val, ") must be in the range [", -input_rank,
948         ", ", input_rank, "), where ", input_rank,
949         " is the number of dimensions in the input.");
950   }
951 
952   // Return the input shape without the dimension being reduced.
953   std::vector<DimensionHandle> dims;
954   for (int i = 0; i < input_rank; ++i) {
955     if (axis != i) {
956       dims.emplace_back(c->Dim(input_shape, i));
957     }
958   }
959   c->set_output(0, c->MakeShape(dims));
960   return Status::OK();
961 }
962 
963 }  // namespace
964 
965 REGISTER_OP("ArgMax")
966     .Input("input: T")
967     .Input("dimension: Tidx")
968     .Output("output: output_type")
969     .Attr("T: numbertype")
970     .Attr("Tidx: {int32, int64} = DT_INT32")
971     .Attr("output_type: {int32, int64} = DT_INT64")
972     .SetShapeFn(ArgOpShape);
973 
974 REGISTER_OP("ArgMin")
975     .Input("input: T")
976     .Input("dimension: Tidx")
977     .Output("output: output_type")
978     .Attr("T: numbertype")
979     .Attr("Tidx: {int32, int64} = DT_INT32")
980     .Attr("output_type: {int32, int64} = DT_INT64")
981     .SetShapeFn(ArgOpShape);
982 
983 namespace {
984 
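// Shape function for the sorted Segment* reductions: the output has an
// unknown leading dimension (the number of segments is data dependent) and
// the same trailing dimensions as 'data'.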
985 Status SegmentReductionShapeFn(InferenceContext* c) {
986   ShapeHandle data_shape;
987   ShapeHandle segment_ids_shape;
988   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
989   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &segment_ids_shape));
990 
991   ShapeHandle subshape;
992   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
993 
994   ShapeHandle out;
995   TF_RETURN_IF_ERROR(
996       c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
997   c->set_output(0, out);
998   return Status::OK();
999 }
1000 
1001 Status SparseSegmentReductionShapeFn(InferenceContext* c) {
1002   ShapeHandle data_shape;
1003   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1004 
1005   ShapeHandle indices_shape;
1006   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
1007 
1008   ShapeHandle segment_ids_shape;
1009   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));
1010 
1011   // indices and segment_ids should merge cleanly.
1012   ShapeHandle unused;
1013   TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));
1014 
1015   ShapeHandle subshape;
1016   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1017 
1018   ShapeHandle out;
1019   TF_RETURN_IF_ERROR(
1020       c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
1021   c->set_output(0, out);
1022   return Status::OK();
1023 }
1024 
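// Shape function for the SparseSegment*Grad ops: the leading output dimension
// is taken from the 'output_dim0' scalar when it is a constant, otherwise it
// is unknown; trailing dimensions follow 'grad'.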
1025 Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
1026   ShapeHandle data_shape;
1027   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1028 
1029   ShapeHandle indices_shape;
1030   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
1031 
1032   // indices and segment_ids should merge cleanly.
1033   ShapeHandle unused;
1034   TF_RETURN_IF_ERROR(c->Merge(c->input(2), indices_shape, &unused));
1035 
1036   // output_dim0 should be a scalar
1037   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1038 
1039   ShapeHandle subshape;
1040   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1041 
1042   const Tensor* dim0 = c->input_tensor(3);
1043   ShapeHandle dim0_shape;
1044   if (dim0 == nullptr) {
1045     // We don't have the value at inference time, so the output
1046     // shape is unknown.
1047     dim0_shape = c->Vector(InferenceContext::kUnknownDim);
1048   } else {
1049     auto dim0_value = dim0->scalar<int32>()();
1050     if (dim0_value < 0) {
1051       return errors::InvalidArgument(
1052           "Cannot specify a negative value for output_dim0");
1053     }
1054     dim0_shape = c->Vector(dim0_value);
1055   }
1056 
1057   ShapeHandle out;
1058   TF_RETURN_IF_ERROR(c->Concatenate(dim0_shape, subshape, &out));
1059   c->set_output(0, out);
1060   return Status::OK();
1061 }
1062 
1063 Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) {
1064   ShapeHandle data_shape;
1065   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1066 
1067   ShapeHandle indices_shape;
1068   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
1069 
1070   ShapeHandle segment_ids_shape;
1071   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));
1072 
1073   ShapeHandle num_segments_shape;
1074   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &num_segments_shape));
1075 
1076   // indices and segment_ids should merge cleanly.
1077   ShapeHandle unused;
1078   TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));
1079 
1080   ShapeHandle subshape;
1081   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1082 
1083   ShapeHandle out;
1084   const Tensor* dim0 = c->input_tensor(3);
1085   if (dim0 == nullptr) {
1086     // We don't have the value at inference time, so the output
1087     // shape is unknown.
1088     TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(InferenceContext::kUnknownDim),
1089                                       subshape, &out));
1090   } else {
1091     auto dim0_value = dim0->scalar<int32>()();
1092     if (dim0_value < 0) {
1093       return errors::InvalidArgument(
1094           "Cannot specify a negative value for num_segments");
1095     }
1096     TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(dim0_value), subshape, &out));
1097   }
1098   c->set_output(0, out);
1099   return Status::OK();
1100 }
1101 
1102 Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
1103   ShapeHandle s_data = c->input(0);
1104   ShapeHandle s_segment_ids = c->input(1);
1105   ShapeHandle s_num_segments = c->input(2);
1106   TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
1107 
1108   ShapeHandle out;
1109 
1110   // Leading dimensions of data must be compatible with dimensions of
1111   // <s_segment_ids>.
1112   if (c->RankKnown(s_segment_ids)) {
1113     TF_RETURN_IF_ERROR(
1114         c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
1115 
1116     // Get the value of the num_segments input tensor.
1117     DimensionHandle num_segments_dim;
1118     TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
1119 
1120     // Output is {segment_id_rank} + s_data[segment_id_rank:].
1121     ShapeHandle s_data_suffix;
1122     TF_RETURN_IF_ERROR(
1123         c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
1124     TF_RETURN_IF_ERROR(
1125         c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
1126   } else {
1127     out = c->UnknownShape();
1128   }
1129   c->set_output(0, out);
1130   return Status::OK();
1131 }
1132 }  // namespace
1133 
1134 REGISTER_OP("SegmentSum")
1135     .Input("data: T")
1136     .Input("segment_ids: Tindices")
1137     .Output("output: T")
1138     .Attr("T: numbertype")
1139     .Attr("Tindices: {int32,int64}")
1140     .SetShapeFn(SegmentReductionShapeFn);
1141 
1142 REGISTER_OP("SegmentMean")
1143     .Input("data: T")
1144     .Input("segment_ids: Tindices")
1145     .Output("output: T")
1146     .Attr("T: numbertype")
1147     .Attr("Tindices: {int32,int64}")
1148     .SetShapeFn(SegmentReductionShapeFn);
1149 
1150 REGISTER_OP("SegmentProd")
1151     .Input("data: T")
1152     .Input("segment_ids: Tindices")
1153     .Output("output: T")
1154     .Attr("T: numbertype")
1155     .Attr("Tindices: {int32,int64}")
1156     .SetShapeFn(SegmentReductionShapeFn);
1157 
1158 REGISTER_OP("SegmentMin")
1159     .Input("data: T")
1160     .Input("segment_ids: Tindices")
1161     .Output("output: T")
1162     .Attr("T: realnumbertype")
1163     .Attr("Tindices: {int32,int64}")
1164     .SetShapeFn(SegmentReductionShapeFn);
1165 
1166 REGISTER_OP("SegmentMax")
1167     .Input("data: T")
1168     .Input("segment_ids: Tindices")
1169     .Output("output: T")
1170     .Attr("T: realnumbertype")
1171     .Attr("Tindices: {int32,int64}")
1172     .SetShapeFn(SegmentReductionShapeFn);
1173 
1174 REGISTER_OP("UnsortedSegmentSum")
1175     .Input("data: T")
1176     .Input("segment_ids: Tindices")
1177     .Input("num_segments: Tnumsegments")
1178     .Output("output: T")
1179     .Attr("T: numbertype")
1180     .Attr("Tindices: {int32,int64}")
1181     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1182     .SetShapeFn(UnsortedSegmentReductionShapeFn);
1183 
1184 REGISTER_OP("UnsortedSegmentMax")
1185     .Input("data: T")
1186     .Input("segment_ids: Tindices")
1187     .Input("num_segments: Tnumsegments")
1188     .Output("output: T")
1189     .Attr("T: realnumbertype")
1190     .Attr("Tindices: {int32,int64}")
1191     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1192     .SetShapeFn(UnsortedSegmentReductionShapeFn);
1193 
1194 REGISTER_OP("UnsortedSegmentMin")
1195     .Input("data: T")
1196     .Input("segment_ids: Tindices")
1197     .Input("num_segments: Tnumsegments")
1198     .Output("output: T")
1199     .Attr("T: realnumbertype")
1200     .Attr("Tindices: {int32,int64}")
1201     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1202     .SetShapeFn(UnsortedSegmentReductionShapeFn);
1203 
1204 REGISTER_OP("UnsortedSegmentProd")
1205     .Input("data: T")
1206     .Input("segment_ids: Tindices")
1207     .Input("num_segments: Tnumsegments")
1208     .Output("output: T")
1209     .Attr("T: numbertype")
1210     .Attr("Tindices: {int32,int64}")
1211     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1212     .SetShapeFn(UnsortedSegmentReductionShapeFn);
1213 
1214 REGISTER_OP("SparseSegmentSum")
1215     .Input("data: T")
1216     .Input("indices: Tidx")
1217     .Input("segment_ids: int32")
1218     .Output("output: T")
1219     .Attr("T: realnumbertype")
1220     .Attr("Tidx: {int32, int64} = DT_INT32")
1221     .SetShapeFn(SparseSegmentReductionShapeFn);
1222 
1223 REGISTER_OP("SparseSegmentSumWithNumSegments")
1224     .Input("data: T")
1225     .Input("indices: Tidx")
1226     .Input("segment_ids: int32")
1227     .Input("num_segments: Tnumsegments")
1228     .Output("output: T")
1229     .Attr("T: realnumbertype")
1230     .Attr("Tidx: {int32, int64} = DT_INT32")
1231     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1232     .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
1233 
1234 REGISTER_OP("SparseSegmentMean")
1235     .Input("data: T")
1236     .Input("indices: Tidx")
1237     .Input("segment_ids: int32")
1238     .Output("output: T")
1239     .Attr("T: {float, double}")
1240     .Attr("Tidx: {int32, int64} = DT_INT32")
1241     .SetShapeFn(SparseSegmentReductionShapeFn);
1242 
1243 REGISTER_OP("SparseSegmentMeanWithNumSegments")
1244     .Input("data: T")
1245     .Input("indices: Tidx")
1246     .Input("segment_ids: int32")
1247     .Input("num_segments: Tnumsegments")
1248     .Output("output: T")
1249     .Attr("T: {float, double}")
1250     .Attr("Tidx: {int32, int64} = DT_INT32")
1251     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1252     .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
1253 
1254 REGISTER_OP("SparseSegmentMeanGrad")
1255     .Input("grad: T")
1256     .Input("indices: Tidx")
1257     .Input("segment_ids: int32")
1258     .Input("output_dim0: int32")
1259     .Output("output: T")
1260     .Attr("T: {float, double}")
1261     .Attr("Tidx: {int32, int64} = DT_INT32")
1262     .SetShapeFn(SparseSegmentReductionGradShapeFn);
1263 
1264 REGISTER_OP("SparseSegmentSqrtN")
1265     .Input("data: T")
1266     .Input("indices: Tidx")
1267     .Input("segment_ids: int32")
1268     .Output("output: T")
1269     .Attr("T: {float, double}")
1270     .Attr("Tidx: {int32, int64} = DT_INT32")
1271     .SetShapeFn(SparseSegmentReductionShapeFn);
1272 
1273 REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
1274     .Input("data: T")
1275     .Input("indices: Tidx")
1276     .Input("segment_ids: int32")
1277     .Input("num_segments: Tnumsegments")
1278     .Output("output: T")
1279     .Attr("T: {float, double}")
1280     .Attr("Tidx: {int32, int64} = DT_INT32")
1281     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1282     .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
1283 
1284 REGISTER_OP("SparseSegmentSqrtNGrad")
1285     .Input("grad: T")
1286     .Input("indices: Tidx")
1287     .Input("segment_ids: int32")
1288     .Input("output_dim0: int32")
1289     .Output("output: T")
1290     .Attr("T: {float, double}")
1291     .Attr("Tidx: {int32, int64} = DT_INT32")
1292     .SetShapeFn(SparseSegmentReductionGradShapeFn);
1293 
1294 REGISTER_OP("All")
1295     .Input("input: bool")
1296     .Input("reduction_indices: Tidx")
1297     .Output("output: bool")
1298     .Attr("keep_dims: bool = false")
1299     .Attr("Tidx: {int32, int64} = DT_INT32")
1300     .SetShapeFn(shape_inference::ReductionShape);
1301 
1302 REGISTER_OP("Any")
1303     .Input("input: bool")
1304     .Input("reduction_indices: Tidx")
1305     .Attr("keep_dims: bool = false")
1306     .Output("output: bool")
1307     .Attr("Tidx: {int32, int64} = DT_INT32")
1308     .SetShapeFn(shape_inference::ReductionShape);
1309 
1310 // --------------------------------------------------------------------------
1311 
1312 namespace {
1313 
1314 template <typename T>
1315 Status RangeSize(const Tensor* start_t, const Tensor* limit_t,
1316                  const Tensor* delta_t, InferenceContext* const c) {
1317   T start = start_t->scalar<T>()();
1318   T limit = limit_t->scalar<T>()();
1319   T delta = delta_t->scalar<T>()();
1320   if (start > limit && delta > 0) {
1321     return errors::InvalidArgument(
1322         "Requires start <= limit when delta > 0: ", start, "/", limit);
1323   }
1324   if (start < limit && delta < 0) {
1325     return errors::InvalidArgument(
1326         "Requires start >= limit when delta < 0: ", start, "/", limit);
1327   }
1328   if (delta == 0) {
1329     return errors::InvalidArgument("Requires delta != 0");
1330   }
1331 
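  // Number of elements: for integer types this is ceil(|limit - start| /
  // |delta|) computed with integer arithmetic; for floating-point types it is
  // std::ceil(|(limit - start) / delta|).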
1332   int64 size =
1333       (std::is_integral<T>::value
1334            ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
1335            : std::ceil(std::abs((limit - start) / delta)));
1336   c->set_output(0, c->Vector(size));
1337   return Status::OK();
1338 }
1339 
1340 }  // namespace
1341 
1342 REGISTER_OP("Range")
1343     .Input("start: Tidx")
1344     .Input("limit: Tidx")
1345     .Input("delta: Tidx")
1346     .Output("output: Tidx")
1347     .Attr("Tidx: {bfloat16, float, double, int32, int64} = DT_INT32")
1348     .SetShapeFn([](InferenceContext* c) {
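      // The output is a 1-D vector. Its length can only be computed when
      // 'start', 'limit', and 'delta' are all constant tensors; otherwise it
      // is left unknown.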
1349       ShapeHandle unused;
1350       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
1351                                       " for 'start'");
1352       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
1353                                       " for 'limit'");
1354       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
1355                                       " for 'delta'");
1356       const Tensor* start_t = c->input_tensor(0);
1357       const Tensor* limit_t = c->input_tensor(1);
1358       const Tensor* delta_t = c->input_tensor(2);
1359       DataType dtype;
1360       TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
1361       if (start_t == nullptr || limit_t == nullptr || delta_t == nullptr) {
1362         c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1363         return Status::OK();
1364       }
1365       if (dtype == DT_INT32) {
1366         return RangeSize<int32>(start_t, limit_t, delta_t, c);
1367       } else if (dtype == DT_INT64) {
1368         return RangeSize<int64>(start_t, limit_t, delta_t, c);
1369       } else if (dtype == DT_FLOAT) {
1370         return RangeSize<float>(start_t, limit_t, delta_t, c);
1371       } else {
1372         return RangeSize<double>(start_t, limit_t, delta_t, c);
1373       }
1374       return Status::OK();
1375     });
1376 
1377 REGISTER_OP("LinSpace")
1378     .Input("start: T")
1379     .Input("stop: T")
1380     .Input("num: Tidx")
1381     .Output("output: T")
1382     .Attr("T: {bfloat16, float, double}")
1383     .Attr("Tidx: {int32, int64} = DT_INT32")
1384     .SetShapeFn([](InferenceContext* c) {
1385       ShapeHandle unused;
1386       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
1387                                       " for 'start'");
1388       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
1389                                       " for 'stop'");
1390       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
1391                                       " for 'num'");
1392       const Tensor* num_t = c->input_tensor(2);
1393       if (num_t == nullptr) {
1394         c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1395         return Status::OK();
1396       }
1397 
1398       int64 num;
1399       if (num_t->dtype() == DT_INT32) {
1400         num = num_t->scalar<int32>()();
1401       } else {
1402         num = num_t->scalar<int64>()();
1403       }
1404       if (num <= 0) return errors::InvalidArgument("Requires num > 0: ", num);
1405       c->set_output(0, c->Vector(num));
1406       return Status::OK();
1407     });
1408 
1409 REGISTER_OP("Complex")
1410     .Input("real: T")
1411     .Input("imag: T")
1412     .Output("out: Tout")
1413     .Attr("T: {float, double} = DT_FLOAT")
1414     .Attr("Tout: {complex64, complex128} = DT_COMPLEX64")
1415     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
1416 
1417 REGISTER_OP("Real")
1418     .Input("input: T")
1419     .Output("output: Tout")
1420     .Attr("T: {complex64, complex128} = DT_COMPLEX64")
1421     .Attr("Tout: {float, double} = DT_FLOAT")
1422     .SetShapeFn(shape_inference::UnchangedShape);
1423 
1424 REGISTER_OP("Imag")
1425     .Input("input: T")
1426     .Output("output: Tout")
1427     .Attr("T: {complex64, complex128} = DT_COMPLEX64")
1428     .Attr("Tout: {float, double} = DT_FLOAT")
1429     .SetShapeFn(shape_inference::UnchangedShape);
1430 
1431 REGISTER_OP("Angle")
1432     .Input("input: T")
1433     .Output("output: Tout")
1434     .Attr("T: {complex64, complex128} = DT_COMPLEX64")
1435     .Attr("Tout: {float, double} = DT_FLOAT")
1436     .SetShapeFn(shape_inference::UnchangedShape);
1437 
1438 REGISTER_OP("Conj")
1439     .Input("input: T")
1440     .Output("output: T")
1441     .Attr("T: {complex64, complex128, variant} = DT_COMPLEX64")
1442     .SetShapeFn([](InferenceContext* c) {
1443       c->set_output(0, c->input(0));
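      // Forward any handle shape/dtype information to the output as well,
      // which is relevant when T is DT_VARIANT.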
1444       auto* handle_data = c->input_handle_shapes_and_types(0);
1445       if (handle_data != nullptr) {
1446         c->set_output_handle_shapes_and_types(0, *handle_data);
1447       }
1448       return Status::OK();
1449     });
1450 
1451 // --------------------------------------------------------------------------
1452 
1453 REGISTER_OP("Cross")
1454     .Input("a: T")
1455     .Input("b: T")
1456     .Output("product: T")
1457     .Attr("T: realnumbertype")
1458     .SetShapeFn([](InferenceContext* c) {
1459       ShapeHandle a_shape;
1460       ShapeHandle b_shape;
1461       // * Input rank >= 1.
1462       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &a_shape));
1463       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &b_shape));
1464 
1465       // * Both inputs have the same shape.
1466       TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &a_shape));
1467 
1468       // * input_shape[-1] == 3.
1469       if (c->RankKnown(a_shape)) {
1470         int rank = c->Rank(a_shape);
1471         auto dim = c->Dim(a_shape, rank - 1);
1472         TF_RETURN_IF_ERROR(c->WithValue(dim, 3, &dim));
1473       }
1474       c->set_output(0, a_shape);
1475       return Status::OK();
1476     });
1477 
1478 // --------------------------------------------------------------------------
1479 
1480 REGISTER_OP("HistogramFixedWidth")
1481     .Input("values: T")
1482     .Input("value_range: T")
1483     .Input("nbins: int32")
1484     .Output("out: dtype")
1485     .Attr("T: {int32, int64, float32, float64}")
1486     .Attr("dtype: {int32, int64} = DT_INT32")
1487     .SetShapeFn([](InferenceContext* c) {
1488       // value_range should be a vector.
1489       ShapeHandle value_range_shape;
1490       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &value_range_shape));
1491       // value_range should have two elements.
1492       DimensionHandle unused;
1493       TF_RETURN_IF_ERROR(
1494           c->WithValue(c->Dim(value_range_shape, 0), 2, &unused));
1495       // nbins should be a scalar.
1496       ShapeHandle nbins_shape;
1497       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &nbins_shape));
1498 
1499       // If nbins is available, set the shape from nbins.
1500       const Tensor* nbins_input = c->input_tensor(2);
1501       if (nbins_input != nullptr) {
1502         int64 nbins;
1503         TF_RETURN_IF_ERROR(c->GetScalarFromTensor(nbins_input, &nbins));
1504         // nbins has to be positive.
1505         if (nbins <= 0) {
1506           return errors::InvalidArgument("Requires nbins > 0: ", nbins);
1507         }
1508         c->set_output(0, c->Vector(nbins));
1509       } else {
1510         c->set_output(0, c->UnknownShapeOfRank(1));
1511       }
1512       return Status::OK();
1513     });
1514 
1515 REGISTER_OP("Bincount")
1516     .Input("arr: int32")
1517     .Input("size: int32")
1518     .Input("weights: T")
1519     .Attr("T: {int32, int64, float32, float64}")
1520     .Output("bins: T")
1521     .SetShapeFn([](InferenceContext* c) {
1522       ShapeHandle unused;
1523       // The input `size` must be a scalar.
1524       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1525 
1526       const Tensor* size_tensor = c->input_tensor(1);
1527       if (size_tensor == nullptr) {
1528         // Return unknown shape if size is not known.
1529         c->set_output(0, c->UnknownShapeOfRank(1));
1530         return Status::OK();
1531       }
1532 
1533       // Return `[size]` shape if size is known.
1534       int32 size_val = size_tensor->scalar<int32>()();
1535       if (size_val < 0) {
1536         return errors::InvalidArgument("size (", size_val,
1537                                        ") must be non-negative");
1538       }
1539       c->set_output(0, c->MakeShape({size_val}));
1540       return Status::OK();
1541     });
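// Example: arr = [1, 1, 2] with a constant size = 4 yields bins of shape [4];
// if size is only known at run time, the output shape is a vector of unknown
// length.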

REGISTER_OP("Cumsum")
    .Input("x: T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Cumprod")
    .Input("x: T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);
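// Both cumulative ops scan along `axis` and produce an output with the same
// shape as `x` (hence UnchangedShape). `exclusive` shifts the scan so that
// element i excludes x[i], and `reverse` scans from the end of the axis;
// e.g. Cumsum of [1, 2, 3] is [1, 3, 6], or [0, 1, 3] with exclusive = true.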

REGISTER_OP("QuantizedMatMul")
    .Input("a: T1")
    .Input("b: T2")
    .Input("min_a: float")
    .Input("max_a: float")
    .Input("min_b: float")
    .Input("max_b: float")
    .Output("out: Toutput")
    .Output("min_out: float")
    .Output("max_out: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Toutput: quantizedtype = DT_QINT32")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("Tactivation: quantizedtype = DT_QUINT8")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));

      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });
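// The quantized matmul follows the same shape rules as MatMul for `out`
// (e.g. a = [m, k], b = [k, n] -> out = [m, n], subject to the transpose
// attrs), while the four min/max inputs and the min_out/max_out outputs are
// all scalars describing the quantization ranges.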

REGISTER_OP("QuantizedMul")
    .Input("x: T1")
    .Input("y: T2")
    .Input("min_x: float")
    .Input("max_x: float")
    .Input("min_y: float")
    .Input("max_y: float")
    .Output("z: Toutput")
    .Output("min_z: float")
    .Output("max_z: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Toutput: quantizedtype = DT_QINT32")
    .SetIsCommutative()
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("QuantizedAdd")
    .Input("x: T1")
    .Input("y: T2")
    .Input("min_x: float")
    .Input("max_x: float")
    .Input("min_y: float")
    .Input("max_y: float")
    .Output("z: Toutput")
    .Output("min_z: float")
    .Output("max_z: float")
    .Attr("T1: quantizedtype")
    .Attr("T2: quantizedtype")
    .Attr("Toutput: quantizedtype = DT_QINT32")
    .SetIsCommutative()
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
      // min_x, max_x, min_y, max_y should be scalar.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));

      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });
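// Both binary quantized ops broadcast x and y exactly like their float
// counterparts (e.g. x = [2, 3] and y = [3] give z = [2, 3]); QuantizedAdd
// additionally checks that its four range inputs are scalars, and both ops
// emit scalar min_z/max_z range outputs.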

REGISTER_OP("QuantizeDownAndShrinkRange")
    .Input("input: Tinput")
    .Input("input_min: float")
    .Input("input_max: float")
    .Output("output: out_type")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("Tinput: quantizedtype")
    .Attr("out_type: quantizedtype")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("Requantize")
    .Input("input: Tinput")
    .Input("input_min: float")
    .Input("input_max: float")
    .Input("requested_output_min: float")
    .Input("requested_output_max: float")
    .Output("output: out_type")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("Tinput: quantizedtype")
    .Attr("out_type: quantizedtype")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });
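// Both ops re-express a quantized tensor in a different quantized type, so the
// element shape passes through unchanged; all range inputs must be scalars and
// the output_min/output_max outputs are scalars as well. The element dtype of
// `output` is controlled by the `out_type` attr, not by shape inference.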

REGISTER_OP("CompareAndBitpack")
    .Input("input: T")
    .Input("threshold: T")
    .Output("output: uint8")
    .Attr("T: {bool, float16, float32, float64, int8, int16, int32, int64}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      ShapeHandle output = input;
      if (c->RankKnown(input)) {
        int rank = c->Rank(input);
        auto inner_dim = c->Dim(input, rank - 1);
        DimensionHandle inferred_dim;
        TF_RETURN_IF_ERROR(c->Divide(inner_dim, 8,
                                     /* evenly_divisible */ true,
                                     &inferred_dim));
        TF_RETURN_IF_ERROR(
            c->ReplaceDim(output, rank - 1, inferred_dim, &output));
      }
      c->set_output(0, output);

      return Status::OK();
    });
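// Each group of 8 comparison results along the innermost dimension is packed
// into one uint8, so that dimension must be evenly divisible by 8 and shrinks
// by a factor of 8; e.g. input = [2, 16] with a scalar threshold gives
// output = [2, 2].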

REGISTER_OP("RequantizationRange")
    .Input("input: Tinput")
    .Input("input_min: float")
    .Input("input_max: float")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("Tinput: quantizedtype")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return Status::OK();
    });

// --------------------------------------------------------------------------

REGISTER_OP("Bucketize")
    .Input("input: T")
    .Output("output: int32")
    .Attr("T: {int32, int64, float, double}")
    .Attr("boundaries: list(float)")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("ClipByValue")
    .Input("t: T")
    .Input("clip_value_min: T")
    .Input("clip_value_max: T")
    .Output("output: T")
    .Attr("T: numbertype")
    .SetShapeFn(shape_inference::UnchangedShape);
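// Bucketize maps every element of `input` to the index of the bucket defined
// by the sorted `boundaries` list, so the int32 output keeps the input's
// shape. ClipByValue likewise reports the shape of `t` unchanged; the clip
// bound inputs are not constrained by this shape function.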

#ifdef INTEL_MKL
REGISTER_OP("_MklAddN")
    .Input("inputs: N * T")
    .Input("mkl_input: N * uint8")
    .Output("sum: T")
    .Output("mkl_sum: uint8")
    .Attr("N: int >= 1")
    .Attr("T: numbertype")
    .SetIsCommutative()
    .SetIsAggregate()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle cur = c->input(c->num_inputs() - 1);
      for (int i = c->num_inputs() - 2; i >= 0; --i) {
        TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
                                        "From merging shape ", i,
                                        " with other shapes.");
      }
      c->set_output(0, cur);
      return Status::OK();
    })
    .Doc(R"doc(
Adds all input tensors element-wise using the MKL sum kernel.
inputs: Must all be the same size and shape.
)doc");

#endif  // INTEL_MKL

REGISTER_OP("RequantizePerChannel")
    .Input("input: T")
    .Input("input_min: float")
    .Input("input_max: float")
    .Input("requested_output_min: float")
    .Input("requested_output_max: float")
    .Output("output: out_type")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("T: quantizedtype = DT_QINT32")
    .Attr("out_type: quantizedtype = DT_QUINT8")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
      c->set_output(1, c->Scalar());
      c->set_output(2, c->Scalar());
      return Status::OK();
    });

REGISTER_OP("RequantizationRangePerChannel")
    .Input("input: T")
    .Input("input_min: float")
    .Input("input_max: float")
    .Output("output_min: float")
    .Output("output_max: float")
    .Attr("T: quantizedtype = DT_QINT32")
    .Attr("clip_value_max: float")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      c->set_output(0, c->Scalar());
      c->set_output(1, c->Scalar());
      return Status::OK();
    });
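// In the per-channel variants, input_min and input_max are rank-1 tensors
// (one value per channel) rather than scalars, which is why their ranks are
// checked against 1 above, while the requested and emitted output ranges
// remain scalars.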

REGISTER_OP("NextAfter")
    .Attr("T: {float64, float32} = DT_FLOAT")
    .Input("x1: T")
    .Input("x2: T")
    .Output("output: T")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

}  // namespace tensorflow