• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/numeric_op.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20 
21 namespace tensorflow {
22 
23 using shape_inference::DimensionHandle;
24 using shape_inference::InferenceContext;
25 using shape_inference::ShapeHandle;
26 
27 REGISTER_OP("AddN")
28     .Input("inputs: N * T")
29     .Output("sum: T")
30     .Attr("N: int >= 1")
31     .Attr("T: {numbertype, variant}")
32     .SetIsCommutative()
33     .SetIsAggregate()
__anon99f049890102(InferenceContext* c) 34     .SetShapeFn([](InferenceContext* c) {
35       ShapeHandle cur = c->input(c->num_inputs() - 1);
36       for (int i = c->num_inputs() - 2; i >= 0; --i) {
37         TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
38                                         "From merging shape ", i,
39                                         " with other shapes.");
40       }
41       c->set_output(0, cur);
42 
43       DataType dtype;
44       TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
45 
46       if (dtype != DT_VARIANT) {
47         // Exit early if not DT_VARIANT.
48         return Status::OK();
49       } else {
50         // DT_VARIANT shape handle shape inference.  All sizes and dtypes must
51         // be the same; all shapes must be compatible via Merge.
52         std::vector<shape_inference::ShapeAndType> cur_shapes_and_types;
53         auto* shapes_and_types =
54             c->input_handle_shapes_and_types(c->num_inputs() - 1);
55         if (shapes_and_types) {
56           cur_shapes_and_types = *shapes_and_types;
57         }
58 
59         for (int i = c->num_inputs() - 2; i >= 0; --i) {
60           auto shapes_and_types_i = c->input_handle_shapes_and_types(i);
61           if (!shapes_and_types && shapes_and_types_i) {
62             // TODO(ebrevdo): Find cases where this happens and fix their shape
63             // inference.  If we are calling AddN on variant types, they should
64             // all have consistent shape_and_type info.
65             shapes_and_types = shapes_and_types_i;
66           } else if (shapes_and_types && shapes_and_types_i) {
67             if (shapes_and_types_i->size() != shapes_and_types->size()) {
68               return errors::InvalidArgument(
69                   "shapes_and_types[", i,
70                   "].size() == ", shapes_and_types_i->size(),
71                   " != shapes_and_types[0].size() == ",
72                   shapes_and_types->size());
73             }
74             for (int j = 0; j < shapes_and_types->size(); ++j) {
75               if (shapes_and_types->at(j).dtype !=
76                   shapes_and_types_i->at(j).dtype) {
77                 return errors::InvalidArgument(
78                     "shapes_and_types[", i, "][", j, "].dtype() == ",
79                     DataTypeString(shapes_and_types_i->at(j).dtype),
80                     " != shapes_and_types[0][", j, "].dtype == ",
81                     DataTypeString(shapes_and_types->at(j).dtype));
82               }
83               TF_RETURN_WITH_CONTEXT_IF_ERROR(
84                   c->Merge(shapes_and_types_i->at(j).shape,
85                            cur_shapes_and_types.at(j).shape,
86                            &cur_shapes_and_types.at(j).shape),
87                   "From merging shapes_and_types[", i, "][", j, "].shape with ",
88                   "shapes_and_types[0][", j, "].shape");
89             }
90           }
91         }
92         if (shapes_and_types) {
93           c->set_output_handle_shapes_and_types(0, cur_shapes_and_types);
94         }
95         return Status::OK();
96       }
97     });
98 
99 // --------------------------------------------------------------------------
100 
101 // Note that the following operator is just a placeholder and has no
102 // associated kernel. The code in accumulate_n_optimizer.cc replaces
103 // this placeholder with a graph of operators that do have kernels.
104 // The Python code that generates instances of this op is currently in
105 // contrib/framework/python/ops/accumulate_n_v2.py
106 REGISTER_OP("AccumulateNV2")
107     .Input("inputs: N * T")
108     .Output("sum: T")
109     .Attr("N: int >= 1")
110     .Attr("T: numbertype")
111     .Attr("shape: shape")
112     .SetIsCommutative()
113     .SetIsAggregate()
114     .SetShapeFn(shape_inference::ExplicitShape);
115 
116 // --------------------------------------------------------------------------
117 
118 REGISTER_OP("BatchMatMul")
119     .Input("x: T")
120     .Input("y: T")
121     .Output("output: T")
122     .Attr(
123         "T: {bfloat16, half, float, double, int32, int64, complex64, "
124         "complex128}")
125     .Attr("adj_x: bool = false")
126     .Attr("adj_y: bool = false")
127     .SetShapeFn(shape_inference::BatchMatMulShape);
128 
129 REGISTER_OP("BatchMatMulV2")
130     .Input("x: T")
131     .Input("y: T")
132     .Output("output: T")
133     .Attr(
134         "T: {bfloat16, half, float, double, int16, int32, int64, complex64, "
135         "complex128}")
136     .Attr("adj_x: bool = false")
137     .Attr("adj_y: bool = false")
138     .SetShapeFn(shape_inference::BatchMatMulV2Shape);
139 
140 REGISTER_OP("BatchMatMulV3")
141     .Input("x: Ta")
142     .Input("y: Tb")
143     .Output("output: Tout")
144     .Attr(
145         "Ta: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
146         "complex64, complex128}")
147     .Attr(
148         "Tb: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
149         "complex64, complex128}")
150     .Attr(
151         "Tout: {bfloat16, half, float, double, int16, int32, int64, complex64, "
152         "complex128}")
153     .Attr("adj_x: bool = false")
154     .Attr("adj_y: bool = false")
155     .SetShapeFn(shape_inference::BatchMatMulV2Shape);
156 
#ifdef INTEL_MKL
// MKL registrations of the batched matmul ops, compiled only when INTEL_MKL
// is defined. Both accept bfloat16 and float and reuse the shape functions of
// their non-MKL counterparts.
REGISTER_OP("_MklBatchMatMul")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulShape);

REGISTER_OP("_MklBatchMatMulV2")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulV2Shape);
#endif  // INTEL_MKL
176 
177 // --------------------------------------------------------------------------
178 // Casting Ops
179 //
180 // NOTE: Only a smaller number of types are supported by
181 // Cast. The exact casting rule is TBD. The current
182 // implementation uses C++ static cast rules for numeric
183 // types, which may be changed in the future.
184 REGISTER_OP("Cast")
185     .Input("x: SrcT")
186     .Output("y: DstT")
187     .Attr("SrcT: type")
188     .Attr("DstT: type")
189     .Attr("Truncate: bool = false")
190     .SetShapeFn(shape_inference::UnchangedShape);
191 
192 REGISTER_OP("_HostCast")
193     .Input("x: SrcT")
194     .Output("y: DstT")
195     .Attr("SrcT: type")
196     .Attr("DstT: type")
197     .Attr("Truncate: bool = false")
198     .SetShapeFn(shape_inference::UnchangedShape)
199     .Doc(R"doc(
200 Cast x of type SrcT to y of DstT.
201 
202 _HostCast requires its input and produces its output in host memory.
203 )doc");
204 
205 // --------------------------------------------------------------------------
206 
207 REGISTER_OP("Abs")
208     .Input("x: T")
209     .Output("y: T")
210     .Attr("T: {bfloat16, half, float, double, int8, int16, int32, int64}")
211     .SetShapeFn(shape_inference::UnchangedShape);
212 
213 REGISTER_OP("ComplexAbs")
214     .Input("x: T")
215     .Output("y: Tout")
216     .Attr("T: {complex64, complex128} = DT_COMPLEX64")
217     .Attr("Tout: {float, double} = DT_FLOAT")
218     .SetShapeFn(shape_inference::UnchangedShape);
219 
// Builder-chain fragments for cwise unary ops. Each expands to the calls that
// follow `REGISTER_OP("Name").` and ends with an UnchangedShape shape fn.

// Signature 't -> 't over floating-point, signed integer and complex types.
#define UNARY() \
  Input("x: T") \
      .Output("y: T") \
      .Attr( \
          "T: {bfloat16, half, float, double, int8, int16, int32, " \
          "int64, complex64, complex128}") \
      .SetShapeFn(shape_inference::UnchangedShape)

// Signature 't -> 't over real floating-point types only.
#define UNARY_REAL() \
  Input("x: T") \
      .Output("y: T") \
      .Attr("T: {bfloat16, half, float, double}") \
      .SetShapeFn(shape_inference::UnchangedShape)

// Signature 't -> 't over real floating-point and complex types.
#define UNARY_COMPLEX() \
  Input("x: T") \
      .Output("y: T") \
      .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
      .SetShapeFn(shape_inference::UnchangedShape)

// Signature ('t y, 't dy) -> 't z over real floating-point and complex types.
#define UNARY_GRADIENT_COMPLEX() \
  Input("y: T") \
      .Input("dy: T") \
      .Output("z: T") \
      .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
      .SetShapeFn(shape_inference::UnchangedShape)
247 
248 REGISTER_OP("Neg").UNARY();
249 
250 REGISTER_OP("Inv").UNARY();
251 
252 REGISTER_OP("InvGrad").UNARY_GRADIENT_COMPLEX();
253 
254 REGISTER_OP("Reciprocal").UNARY();
255 
256 REGISTER_OP("ReciprocalGrad").UNARY_GRADIENT_COMPLEX();
257 
258 REGISTER_OP("Square").UNARY();
259 
260 REGISTER_OP("Sqrt").UNARY_COMPLEX();
261 
262 REGISTER_OP("SqrtGrad").UNARY_GRADIENT_COMPLEX();
263 
264 REGISTER_OP("Rsqrt").UNARY_COMPLEX();
265 
266 REGISTER_OP("Round").UNARY();
267 
268 REGISTER_OP("RsqrtGrad").UNARY_GRADIENT_COMPLEX();
269 
270 REGISTER_OP("Exp").UNARY_COMPLEX();
271 
272 REGISTER_OP("Expm1").UNARY_COMPLEX();
273 
274 REGISTER_OP("Log").UNARY_COMPLEX();
275 
276 REGISTER_OP("Log1p").UNARY_COMPLEX();
277 
278 REGISTER_OP("Sinh").UNARY_COMPLEX();
279 
280 REGISTER_OP("Cosh").UNARY_COMPLEX();
281 
282 REGISTER_OP("Tanh").UNARY_COMPLEX();
283 
284 REGISTER_OP("Asinh").UNARY_COMPLEX();
285 
286 REGISTER_OP("Acosh").UNARY_COMPLEX();
287 
288 REGISTER_OP("Atanh").UNARY_COMPLEX();
289 
290 REGISTER_OP("TanhGrad").UNARY_GRADIENT_COMPLEX();
291 
292 REGISTER_OP("Lgamma").UNARY_REAL();
293 
294 REGISTER_OP("Digamma").UNARY_REAL();
295 
296 REGISTER_OP("Erf").UNARY_REAL();
297 REGISTER_OP("Erfinv").UNARY_REAL();
298 REGISTER_OP("Ndtri").UNARY_REAL();
299 REGISTER_OP("Erfc").UNARY_REAL();
300 
301 REGISTER_OP("Sigmoid").UNARY_COMPLEX();
302 
303 REGISTER_OP("SigmoidGrad").UNARY_GRADIENT_COMPLEX();
304 
305 REGISTER_OP("Sin").UNARY_COMPLEX();
306 
307 REGISTER_OP("Cos").UNARY_COMPLEX();
308 
309 REGISTER_OP("Tan").UNARY();
310 
311 REGISTER_OP("Asin").UNARY();
312 
313 REGISTER_OP("Acos").UNARY();
314 
315 REGISTER_OP("Atan").UNARY();
316 
317 REGISTER_OP("_UnaryOpsComposition")
318     .Input("x: T")
319     .Output("y: T")
320     .Attr("T: {float, half, double}")
321     .Attr("op_names: list(string)")
322     .SetShapeFn(shape_inference::UnchangedShape)
323     .Doc(R"doc(
324 *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
325 expected to create these operators.
326 )doc");
327 
// The unary helper macros are only meant for the registrations above; drop
// all four so none of them stays defined for the rest of the file.
#undef UNARY
#undef UNARY_REAL
#undef UNARY_COMPLEX
#undef UNARY_GRADIENT_COMPLEX
331 
332 REGISTER_OP("IsNan")
333     .Input("x: T")
334     .Output("y: bool")
335     .Attr("T: {bfloat16, half, float, double}")
336     .SetShapeFn(shape_inference::UnchangedShape);
337 
338 REGISTER_OP("IsInf")
339     .Input("x: T")
340     .Output("y: bool")
341     .Attr("T: {bfloat16, half, float, double}")
342     .SetShapeFn(shape_inference::UnchangedShape);
343 
344 REGISTER_OP("IsFinite")
345     .Input("x: T")
346     .Output("y: bool")
347     .Attr("T: {bfloat16, half, float, double}")
348     .SetShapeFn(shape_inference::UnchangedShape);
349 
350 REGISTER_OP("Sign")
351     .Input("x: T")
352     .Output("y: T")
353     .Attr(
354         "T: {bfloat16, half, float, double, int8, int16, int32, int64, "
355         "complex64, complex128}")
356     .SetShapeFn(shape_inference::UnchangedShape);
357 
358 REGISTER_OP("Floor")
359     .Input("x: T")
360     .Output("y: T")
361     .Attr("T: {bfloat16, half, float, double}")
362     .SetShapeFn(shape_inference::UnchangedShape);
363 
364 REGISTER_OP("Ceil")
365     .Input("x: T")
366     .Output("y: T")
367     .Attr("T: {bfloat16, half, float, double}")
368     .SetShapeFn(shape_inference::UnchangedShape);
369 
370 REGISTER_OP("Rint")
371     .Input("x: T")
372     .Output("y: T")
373     .Attr("T: {bfloat16, half, float, double}")
374     .SetShapeFn(shape_inference::UnchangedShape);
375 
// Builder-chain fragments for cwise binary ops with signature 't, 't -> 't.
// They declare only the inputs, the output and the allowed types; each user
// supplies its own shape function.

// Wide type list, including the small and unsigned integer types.
#define BINARY_MORE() \
  Input("x: T") \
      .Input("y: T") \
      .Output("z: T") \
      .Attr( \
          "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, " \
          "int32, uint32, uint64, int64, complex64, complex128}")

// Narrower type list: floating point, int32/int64 and complex.
#define BINARY_FEWER() \
  Input("x: T") \
      .Input("y: T") \
      .Output("z: T") \
      .Attr( \
          "T: {bfloat16, half, float, double, int32, int64, " \
          "complex64, complex128}")
387 
388 REGISTER_OP("Add")
389     .Input("x: T")
390     .Input("y: T")
391     .Output("z: T")
392     .Attr(
393         "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
394         "complex64, complex128, string}")
395     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
396 
397 REGISTER_OP("AddV2")
398     .Input("x: T")
399     .Input("y: T")
400     .Output("z: T")
401     .Attr(
402         "T: {bfloat16, half, float, double, uint8, uint16, uint32, uint64, "
403         "int8, int16, int32, int64, complex64, complex128}")
404     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
405     .SetIsAggregate()
406     .SetIsCommutative();
407 
#ifdef INTEL_MKL
// MKL variants of Add / AddV2, compiled only when INTEL_MKL is defined. Each
// takes two extra uint8 inputs (mkl_x, mkl_y) and produces one extra uint8
// output (mkl_z) next to the data tensors.
REGISTER_OP("_MklAdd")
    .Input("x: T")
    .Input("y: T")
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("z: T")
    .Output("mkl_z: uint8")
    .Attr(
        "T: {half, float, double, uint8, int8, int16, int32, int64, "
        "complex64, complex128, string, bfloat16}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns `x` + `y` element-wise.

*NOTE*: `tf.math.add` supports broadcasting. `tf.math.add_n` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
)doc");

REGISTER_OP("_MklAddV2")
    .Input("x: T")
    .Input("y: T")
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("z: T")
    .Output("mkl_z: uint8")
    .Attr(
        "T: {bfloat16, half, float, double, uint8, int8, int16, int32, "
        "int64, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .SetIsAggregate()
    .SetIsCommutative()
    .Doc(R"doc(
Returns `x` + `y` element-wise.
*NOTE*: `tf.math.add` supports broadcasting. `tf.math.add_n` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
)doc");
#endif  // INTEL_MKL
446 
447 REGISTER_OP("Sub")
448     .Input("x: T")
449     .Input("y: T")
450     .Output("z: T")
451     .Attr(
452         "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, int32, "
453         "int64, complex64, complex128, uint32, uint64}")
454     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
455 
456 REGISTER_OP("_MklSub")
457     .BINARY_FEWER()
458     .Input("mkl_x: uint8")
459     .Input("mkl_y: uint8")
460     .Output("mkl_z: uint8")
461     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
462     .Doc(R"doc(
463 Returns x - y element-wise.
464 
465 *NOTE*: `Sub` supports broadcasting. More about broadcasting
466 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
467 )doc");
468 
469 REGISTER_OP("Mul").BINARY_MORE().SetIsCommutative().SetShapeFn(
470     shape_inference::BroadcastBinaryOpShapeFn);
471 
472 REGISTER_OP("MulNoNan")
473     .Input("x: T")
474     .Input("y: T")
475     .Output("z: T")
476     .Attr("T: {bfloat16, half, float, double, complex64, complex128}")
477     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
478 
479 // Note: This op is not commutative w.r.t. to all its inputs.
480 REGISTER_OP("_MklMul")
481     .BINARY_MORE()
482     .Input("mkl_x: uint8")
483     .Input("mkl_y: uint8")
484     .Output("mkl_z: uint8")
485     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
486     .Doc(R"doc(
487 Returns x * y element-wise.
488 
489 *NOTE*: `Mul` supports broadcasting. More about broadcasting
490 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
491 )doc");
492 
493 REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
494     shape_inference::BroadcastBinaryOpShapeFn);
495 
496 REGISTER_OP("DivNoNan")
497     .Input("x: T")
498     .Input("y: T")
499     .Output("z: T")
500     .Attr("T: {half, float, bfloat16, double, complex64, complex128}")
501     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
502 
503 REGISTER_OP("FloorDiv")
504     .BINARY_MORE()
505     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
506 
507 REGISTER_OP("TruncateDiv")
508     .BINARY_MORE()
509     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
510 
511 REGISTER_OP("RealDiv").BINARY_MORE().SetShapeFn(
512     shape_inference::BroadcastBinaryOpShapeFn);
513 
514 // Note SquaredDifference implements conj(x - y)*(x - y).
515 REGISTER_OP("SquaredDifference")
516     .BINARY_FEWER()
517     .SetIsCommutative()
518     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
519 
520 // Note: This op is not commutative w.r.t. to all its inputs.
521 REGISTER_OP("_MklSquaredDifference")
522     .BINARY_FEWER()
523     .Input("mkl_x: uint8")
524     .Input("mkl_y: uint8")
525     .Output("mkl_z: uint8")
526     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
527     .Doc(R"doc(
528 Returns (x - y)(x - y) element-wise.
529 
530 *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
531 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
532 )doc");
533 
534 REGISTER_OP("Xlogy")
535     .Input("x: T")
536     .Input("y: T")
537     .Output("z: T")
538     .Attr("T: {half, float, double, complex64, complex128}")
539     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
540 
541 REGISTER_OP("Xlog1py")
542     .Input("x: T")
543     .Input("y: T")
544     .Output("z: T")
545     .Attr("T: {half, float, double, complex64, complex128}")
546     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
547 
548 REGISTER_OP("Xdivy")
549     .Input("x: T")
550     .Input("y: T")
551     .Output("z: T")
552     .Attr("T: {half, float, double, complex64, complex128}")
553     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
554 
555 #undef BINARY_FEWER
556 #undef BINARY_MORE
557 
558 REGISTER_OP("Maximum")
559     .Input("x: T")
560     .Input("y: T")
561     .Output("z: T")
562     .Attr(
563         "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, "
564         "int32, uint32, int64, uint64}")
565     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
566 
567 // Note: This op is not commutative w.r.t. to all its inputs.
568 REGISTER_OP("_MklMaximum")
569     .Input("x: T")
570     .Input("y: T")
571     .Input("mkl_x: uint8")
572     .Input("mkl_y: uint8")
573     .Output("z: T")
574     .Output("mkl_z: uint8")
575     .Attr("T: {half, float, double, int32, int64, bfloat16}")
576     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
577     .Doc(R"doc(
578 Returns the max of x and y (i.e. x > y ? x : y) element-wise.
579 
580 *NOTE*: `Maximum` supports broadcasting. More about broadcasting
581 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
582 )doc");
583 
584 REGISTER_OP("Minimum")
585     .Input("x: T")
586     .Input("y: T")
587     .Output("z: T")
588     .Attr(
589         "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, "
590         "int32, uint32, int64, uint64}")
591     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
592 
593 REGISTER_OP("Mod")
594     .Input("x: T")
595     .Input("y: T")
596     .Output("z: T")
597     .Attr("T: {int32, int64, float16, half, bfloat16, float, double}")
598     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
599 
600 REGISTER_OP("FloorMod")
601     .Input("x: T")
602     .Input("y: T")
603     .Output("z: T")
604     .Attr(
605         "T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64, "
606         "bfloat16, half, float, double}")
607     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
608 
609 REGISTER_OP("TruncateMod")
610     .Input("x: T")
611     .Input("y: T")
612     .Output("z: T")
613     .Attr("T: {int32, int64, bfloat16, half, float, double}")
614     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
615 
616 REGISTER_OP("Pow")
617     .Input("x: T")
618     .Input("y: T")
619     .Output("z: T")
620     .Attr(
621         "T: {bfloat16, float, half, double, int8, int16, int32, int64, "
622         "complex64, complex128}")
623     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
624 
625 REGISTER_OP("Igammac")
626     .Input("a: T")
627     .Input("x: T")
628     .Output("z: T")
629     .Attr("T: {bfloat16, half, float, double}")
630     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
631 
632 REGISTER_OP("Igamma")
633     .Input("a: T")
634     .Input("x: T")
635     .Output("z: T")
636     .Attr("T: {bfloat16, half, float, double}")
637     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
638 
639 REGISTER_OP("IgammaGradA")
640     .Input("a: T")
641     .Input("x: T")
642     .Output("z: T")
643     .Attr("T: {float, double}")
644     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
645 
646 REGISTER_OP("Zeta")
647     .Input("x: T")
648     .Input("q: T")
649     .Output("z: T")
650     .Attr("T: {float, double}")
651     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
652 
653 REGISTER_OP("Polygamma")
654     .Input("a: T")
655     .Input("x: T")
656     .Output("z: T")
657     .Attr("T: {float, double}")
658     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
659 
660 REGISTER_OP("Atan2")
661     .Input("y: T")
662     .Input("x: T")
663     .Output("z: T")
664     .Attr("T: {bfloat16, half, float, double}")
665     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
666 
667 REGISTER_OP("Betainc")
668     .Input("a: T")
669     .Input("b: T")
670     .Input("x: T")
671     .Output("z: T")
672     .Attr("T: {float, double}")
__anon99f049890202(InferenceContext* c) 673     .SetShapeFn([](InferenceContext* c) {
674       const int num_inputs = 3;
675       ShapeHandle output = c->UnknownShape();
676       int num_scalars = 0;
677       ShapeHandle some_non_scalar;
678       for (int i = 0; i < num_inputs; ++i) {
679         ShapeHandle in = c->input(i);
680         if (!c->RankKnown(in)) {
681           some_non_scalar = in;
682           // An input with unknown rank could be either a scalar (to be
683           // broadcast) or some other shape.
684         } else if (c->Rank(in) == 0) {
685           // Input is a scalar, it will be broadcast to the output shape.
686           ++num_scalars;
687         } else {
688           TF_RETURN_IF_ERROR(c->Merge(output, in, &output));
689           some_non_scalar = output;
690         }
691       }
692 
693       if (num_scalars == num_inputs - 1) {
694         // If all but one input is known to be a scalar, then output is the
695         // remaining input.
696         output = some_non_scalar;
697       } else if (num_scalars == num_inputs) {
698         // If all are scalars, output is scalar; pick the first one arbitrarily.
699         output = c->input(0);
700       }
701 
702       c->set_output(0, output);
703       return Status::OK();
704     });
705 
706 // --------------------------------------------------------------------------
707 
708 // Declares cwise binary comparison operations signature: 't, 't -> bool,
709 // where 't has a natural total order.
710 #define COMPARISON()             \
711   Input("x: T")                  \
712       .Input("y: T")             \
713       .Output("z: bool")         \
714       .Attr("T: realnumbertype") \
715       .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
716 
717 REGISTER_OP("Less").COMPARISON();
718 
719 REGISTER_OP("LessEqual").COMPARISON();
720 
721 REGISTER_OP("Greater").COMPARISON();
722 
723 REGISTER_OP("GreaterEqual").COMPARISON();
724 
725 #undef COMPARISON
726 
727 // --------------------------------------------------------------------------
728 
729 #define EQUALITY_COMPARISON()                                      \
730   Input("x: T")                                                    \
731       .Input("y: T")                                               \
732       .Output("z: bool")                                           \
733       .SetIsCommutative()                                          \
734       .Attr("T: type")                                             \
735       .Attr("incompatible_shape_error: bool = true")               \
736       .SetShapeFn([](InferenceContext* c) {                        \
737         ShapeHandle x = c->input(0);                               \
738         ShapeHandle y = c->input(1);                               \
739         ShapeHandle output;                                        \
740         bool incompatible_shape_error;                             \
741         TF_RETURN_IF_ERROR(c->GetAttr("incompatible_shape_error",  \
742                                       &incompatible_shape_error)); \
743         TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(   \
744             c, x, y, incompatible_shape_error, &output));          \
745         c->set_output(0, output);                                  \
746         return Status::OK();                                       \
747       })
748 
749 REGISTER_OP("Equal").EQUALITY_COMPARISON();
750 
751 REGISTER_OP("NotEqual").EQUALITY_COMPARISON();
752 
753 #undef EQUALITY_COMPARISON
754 
755 REGISTER_OP("ApproximateEqual")
756     .Input("x: T")
757     .Input("y: T")
758     .Output("z: bool")
759     .SetIsCommutative()
760     .Attr("T: numbertype")
761     .Attr("tolerance: float = 0.00001")
__anon99f049890302(InferenceContext* c) 762     .SetShapeFn([](InferenceContext* c) {
763       // The inputs 'x' and 'y' must have the same shape.
764       ShapeHandle data_x = c->input(0);
765       ShapeHandle data_y = c->input(1);
766       TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x));
767       return shape_inference::UnchangedShape(c);
768     });
769 
770 // --------------------------------------------------------------------------
771 
772 REGISTER_OP("LogicalNot")
773     .Input("x: bool")
774     .Output("y: bool")
775     .SetShapeFn(shape_inference::UnchangedShape);
776 
777 #define BINARY_LOGICAL()  \
778   Input("x: bool")        \
779       .Input("y: bool")   \
780       .Output("z: bool")  \
781       .SetIsCommutative() \
782       .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
783 
784 REGISTER_OP("LogicalAnd").BINARY_LOGICAL();
785 
786 REGISTER_OP("LogicalOr").BINARY_LOGICAL();
787 
788 #undef BINARY_LOGICAL
789 
790 // --------------------------------------------------------------------------
791 
792 REGISTER_OP("Select")
793     .Input("condition: bool")
794     .Input("t: T")
795     .Input("e: T")
796     .Output("output: T")
797     .Attr("T: type")
__anon99f049890402(InferenceContext* c) 798     .SetShapeFn([](InferenceContext* c) {
799       auto* handle_data_1 = c->input_handle_shapes_and_types(1);
800       auto* handle_data_2 = c->input_handle_shapes_and_types(2);
801       // Merge handle shape and dtype if applicable.
802       if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
803         const auto size = handle_data_1->size();
804         std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
805         if (size != handle_data_2->size()) {
806           return errors::InvalidArgument(
807               "Trying to merge handles pointing to different numbers of "
808               "tensors.");
809         }
810 
811         for (int i = 0; i < size; ++i) {
812           const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
813           const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
814           if (s1.dtype != s2.dtype) {
815             // TODO(apassos) resolve this in the manner of b/32476923
816             return errors::InvalidArgument(
817                 "Trying to merge handles pointing to different dtypes.");
818           }
819           merged_handle_data[i].dtype = s1.dtype;
820           TF_RETURN_IF_ERROR(
821               c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
822         }
823 
824         c->set_output_handle_shapes_and_types(0, merged_handle_data);
825       }
826 
827       // The inputs 'then' and 'else' must have the same shape.
828       ShapeHandle data = c->input(1);
829       ShapeHandle other = c->input(2);
830       TF_RETURN_IF_ERROR(c->Merge(data, other, &data));
831 
832       // The input 'cond' must either have the same shape as 'then' and
833       // 'else', or be a vector if 'then' and 'else' are at least vectors.
834       ShapeHandle cond = c->input(0);
835 
836       if (!c->RankKnown(cond) || !c->RankKnown(data)) {
837         c->set_output(0, data);
838         return Status::OK();
839       }
840 
841       // rank of shape and data is known.
842 
843       const int32_t cond_rank = c->Rank(cond);
844       const int32_t data_rank = c->Rank(data);
845 
846       if (cond_rank == 0) {
847         // The rank of 'cond' is a scalar.
848         // t and e can have any shape.
849         c->set_output(0, data);
850         return Status::OK();
851       }
852 
853       if (cond_rank != 1) {
854         // If 'cond' is not a vector, and not a scalar,
855         // then shape must match 'then' and 'else'
856         TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
857         c->set_output(0, data);
858         return Status::OK();
859       }
860 
861       if (data_rank == 0) {
862         // if 'then' and 'else' are scalar also the cond must be
863         TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
864         c->set_output(0, data);
865         return Status::OK();
866       }
867 
868       if (cond_rank == 1) {
869         // if the cond is a vector and the 'then' is not a scalar,
870         // the first dimension of 'then' and 'else'
871         TF_RETURN_IF_ERROR(c->Merge(cond, c->Vector(c->Dim(data, 0)), &cond));
872         c->set_output(0, data);
873         return Status::OK();
874       }
875 
876       c->set_output(0, data);
877 
878       return Status::OK();
879     });
880 
881 REGISTER_OP("SelectV2")
882     .Input("condition: bool")
883     .Input("t: T")
884     .Input("e: T")
885     .Output("output: T")
886     .Attr("T: type")
__anon99f049890502(InferenceContext* c) 887     .SetShapeFn([](InferenceContext* c) {
888       auto* handle_data_1 = c->input_handle_shapes_and_types(1);
889       auto* handle_data_2 = c->input_handle_shapes_and_types(2);
890       // Merge handle shape and dtype if applicable.
891       if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
892         const auto size = handle_data_1->size();
893         std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
894         if (size != handle_data_2->size()) {
895           return errors::InvalidArgument(
896               "Trying to merge handles pointing to different numbers of "
897               "tensors.");
898         }
899 
900         for (int i = 0; i < size; ++i) {
901           const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
902           const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
903           if (s1.dtype != s2.dtype) {
904             // TODO(apassos) resolve this in the manner of b/32476923
905             return errors::InvalidArgument(
906                 "Trying to merge handles pointing to different dtypes.");
907           }
908           merged_handle_data[i].dtype = s1.dtype;
909           TF_RETURN_IF_ERROR(
910               c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
911         }
912 
913         c->set_output_handle_shapes_and_types(0, merged_handle_data);
914       }
915 
916       // The inputs 'cond', 'then', and 'else' must be broadcastable.
917       // TODO (yongtang): Consolidate 3-ary broadcast instead of
918       // multiple 2-ary broadcast.
919       ShapeHandle cond = c->input(0);
920       ShapeHandle then = c->input(1);
921       ShapeHandle else_ = c->input(2);
922       ShapeHandle other;
923       TF_RETURN_IF_ERROR(
924           BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, true, &other));
925       ShapeHandle output;
926       TF_RETURN_IF_ERROR(
927           BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, true, &output));
928       c->set_output(0, output);
929       return Status::OK();
930     });
931 
932 // --------------------------------------------------------------------------
933 
934 REGISTER_OP("MatMul")
935     .Input("a: T")
936     .Input("b: T")
937     .Output("product: T")
938     .Attr("transpose_a: bool = false")
939     .Attr("transpose_b: bool = false")
940     .Attr(
941         "T: {bfloat16, half, float, double, int32, int64, complex64, "
942         "complex128}")
943     .SetShapeFn(shape_inference::MatMulShape);
944 
945 #ifdef INTEL_MKL
946 REGISTER_OP("_MklMatMul")
947     .Input("a: T")
948     .Input("b: T")
949     .Output("product: T")
950     .Attr("transpose_a: bool = false")
951     .Attr("transpose_b: bool = false")
952     .Attr("T: {bfloat16, float}")
953     .SetShapeFn(shape_inference::MatMulShape);
954 #endif  // INTEL_MKL
955 
956 REGISTER_OP("SparseMatMul")
957     .Input("a: Ta")
958     .Input("b: Tb")
959     .Output("product: float")
960     .Attr("transpose_a: bool = false")
961     .Attr("transpose_b: bool = false")
962     .Attr("a_is_sparse: bool = false")
963     .Attr("b_is_sparse: bool = false")
964     .Attr("Ta: {float, bfloat16} = DT_FLOAT")
965     .Attr("Tb: {float, bfloat16} = DT_FLOAT")
966     .SetShapeFn(shape_inference::MatMulShape);
967 
968 REGISTER_OP("_FusedMatMul")
969     .Input("a: T")
970     .Input("b: T")
971     .Input("args: num_args * T")
972     .Output("product: T")
973     .Attr("transpose_a: bool = false")
974     .Attr("transpose_b: bool = false")
975     .Attr("T: {bfloat16, float}")
976     .Attr("num_args: int >= 0")
977     .Attr("fused_ops: list(string) = []")
978     // Attributes for the FusedBatchNorm ----------- //
979     .Attr("epsilon: float = 0.0001")
980     // Attributes for the LeakyRelu ---------------- //
981     .Attr("leakyrelu_alpha: float = 0.2")
982     // --------------------------------------------- //
983     .SetShapeFn(shape_inference::MatMulShape)
984     .Doc(R"doc(
985 Performs a MatMul followed by a specified series of operations.
986 
987 The inputs to the MatMul are specified by `a` and `b`. The series of operations
988 that follows is specified by the `fused_ops` attribute, which is a list of TF op
989 names specified as strings (e.g. "Relu"). They are performed in order, where the
990 (first) input to each op is the output of the preceding op. The first input and
991 the output of each fused_op must be of type T.
992 
993 Currently supported fused_op combinations are: ["BiasAdd"] and ["BiasAdd",A],
994 where A is one of {"Elu","Relu","Relu6"}.
995 
996 * The first input to BiasAdd is the Conv2D result, and the additional BiasAdd
997 input is specified by `args`.
998 * If there is an op A specified, the output of the BiasAdd is the input to op A,
999 and op A produces the _FusedConv2D output. Otherwise, the BiasAdd produces the
1000 _FusedConv2D output.
1001 
1002 *NOTE*: Do not invoke this operator directly in Python. Grappler is
1003 expected to create these operators.
1004 )doc");
1005 
1006 // --------------------------------------------------------------------------
1007 
1008 // For operations where the output is a reduction function along some
1009 // dimensions of the input.
1010 REGISTER_OP("Sum")
1011     .Input("input: T")
1012     .Input("reduction_indices: Tidx")
1013     .Output("output: T")
1014     .Attr("keep_dims: bool = false")
1015     .Attr("T: numbertype")
1016     .Attr("Tidx: {int32, int64} = DT_INT32")
1017     .SetShapeFn(shape_inference::ReductionShape);
1018 
1019 REGISTER_OP("EuclideanNorm")
1020     .Input("input: T")
1021     .Input("reduction_indices: Tidx")
1022     .Output("output: T")
1023     .Attr("keep_dims: bool = false")
1024     .Attr("T: numbertype")
1025     .Attr("Tidx: {int32, int64} = DT_INT32")
1026     .SetShapeFn(shape_inference::ReductionShape);
1027 
1028 REGISTER_OP("Mean")
1029     .Input("input: T")
1030     .Input("reduction_indices: Tidx")
1031     .Output("output: T")
1032     .Attr("keep_dims: bool = false")
1033     .Attr("T: numbertype")
1034     .Attr("Tidx: {int32, int64} = DT_INT32")
1035     .SetShapeFn(shape_inference::ReductionShape);
1036 
1037 REGISTER_OP("Prod")
1038     .Input("input: T")
1039     .Input("reduction_indices: Tidx")
1040     .Output("output: T")
1041     .Attr("keep_dims: bool = false")
1042     .Attr("T: numbertype")
1043     .Attr("Tidx: {int32, int64} = DT_INT32")
1044     .SetShapeFn(shape_inference::ReductionShape);
1045 
1046 REGISTER_OP("Min")
1047     .Input("input: T")
1048     .Input("reduction_indices: Tidx")
1049     .Output("output: T")
1050     .Attr("keep_dims: bool = false")
1051     .Attr("T: {realnumbertype, quantizedtype}")
1052     .Attr("Tidx: {int32, int64} = DT_INT32")
1053     .SetShapeFn(shape_inference::ReductionShape);
1054 
1055 REGISTER_OP("Max")
1056     .Input("input: T")
1057     .Input("reduction_indices: Tidx")
1058     .Output("output: T")
1059     .Attr("keep_dims: bool = false")
1060     .Attr("T: {realnumbertype, quantizedtype}")
1061     .Attr("Tidx: {int32, int64} = DT_INT32")
1062     .SetShapeFn(shape_inference::ReductionShape);
1063 
1064 namespace {
1065 
ArgOpShape(shape_inference::InferenceContext * c)1066 Status ArgOpShape(shape_inference::InferenceContext* c) {
1067   ShapeHandle dimension_shape;
1068   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &dimension_shape));
1069 
1070   ShapeHandle input_shape = c->input(0);
1071   if (!c->RankKnown(input_shape)) {
1072     return shape_inference::UnknownShape(c);
1073   }
1074 
1075   const int32_t input_rank = c->Rank(input_shape);
1076   if (input_rank <= 1) {
1077     // Reducing a scalar/vector must return a scalar.
1078     return shape_inference::ScalarShape(c);
1079   }
1080 
1081   const Tensor* dim_t = c->input_tensor(1);
1082   if (dim_t == nullptr) {
1083     // We don't know the value of the dimension, but we
1084     // know the rank of the input, so return the correct
1085     // rank with unknown dimensions.
1086     std::vector<DimensionHandle> dims(input_rank - 1);
1087     for (int i = 0; i < dims.size(); ++i) {
1088       dims[i] = c->UnknownDim();
1089     }
1090 
1091     c->set_output(0, c->MakeShape(dims));
1092     return Status::OK();
1093   }
1094 
1095   int64_t dimension_val;
1096   if (dim_t->dtype() == DT_INT32) {
1097     dimension_val = dim_t->scalar<int32>()();
1098   } else {
1099     dimension_val = dim_t->scalar<int64>()();
1100   }
1101 
1102   int64_t axis = dimension_val < 0 ? dimension_val + input_rank : dimension_val;
1103   if (axis < 0 || axis >= input_rank) {
1104     return errors::InvalidArgument(
1105         "Dimension (", dimension_val, ") must be in the range [", -input_rank,
1106         ", ", input_rank, "), where ", input_rank,
1107         " is the number of dimensions in the input.");
1108   }
1109 
1110   // Return the input shape without the dimension being reduced.
1111   std::vector<DimensionHandle> dims;
1112   for (int i = 0; i < input_rank; ++i) {
1113     if (axis != i) {
1114       dims.emplace_back(c->Dim(input_shape, i));
1115     }
1116   }
1117   c->set_output(0, c->MakeShape(dims));
1118   return Status::OK();
1119 }
1120 
1121 }  // namespace
1122 
1123 REGISTER_OP("ArgMax")
1124     .Input("input: T")
1125     .Input("dimension: Tidx")
1126     .Output("output: output_type")
1127     .Attr("T: {numbertype, bool}")
1128     .Attr("Tidx: {int32, int64} = DT_INT32")
1129     .Attr("output_type: {int32, int64} = DT_INT64")
1130     .SetShapeFn(ArgOpShape);
1131 
1132 REGISTER_OP("ArgMin")
1133     .Input("input: T")
1134     .Input("dimension: Tidx")
1135     .Output("output: output_type")
1136     .Attr("T: {numbertype, bool}")
1137     .Attr("Tidx: {int32, int64} = DT_INT32")
1138     .Attr("output_type: {int32, int64} = DT_INT64")
1139     .SetShapeFn(ArgOpShape);
1140 
1141 namespace {
1142 
SegmentReductionShapeFn(InferenceContext * c)1143 Status SegmentReductionShapeFn(InferenceContext* c) {
1144   ShapeHandle data_shape;
1145   ShapeHandle segment_ids_shape;
1146   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1147   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &segment_ids_shape));
1148 
1149   ShapeHandle subshape;
1150   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1151 
1152   ShapeHandle out;
1153   TF_RETURN_IF_ERROR(
1154       c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
1155   c->set_output(0, out);
1156   return Status::OK();
1157 }
1158 
SparseSegmentReductionShapeFn(InferenceContext * c)1159 Status SparseSegmentReductionShapeFn(InferenceContext* c) {
1160   ShapeHandle data_shape;
1161   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1162 
1163   ShapeHandle indices_shape;
1164   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
1165 
1166   ShapeHandle segment_ids_shape;
1167   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));
1168 
1169   // indices and segment_ids should merge cleanly.
1170   ShapeHandle unused;
1171   TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));
1172 
1173   ShapeHandle subshape;
1174   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1175 
1176   ShapeHandle out;
1177   TF_RETURN_IF_ERROR(
1178       c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
1179   c->set_output(0, out);
1180   return Status::OK();
1181 }
1182 
SparseSegmentReductionGradShapeFn(InferenceContext * c)1183 Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
1184   ShapeHandle data_shape;
1185   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1186 
1187   ShapeHandle indices_shape;
1188   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
1189 
1190   // indices and segment_ids should merge cleanly.
1191   ShapeHandle unused;
1192   TF_RETURN_IF_ERROR(c->Merge(c->input(2), indices_shape, &unused));
1193 
1194   // output_dim0 should be a scalar
1195   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1196 
1197   ShapeHandle subshape;
1198   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1199 
1200   const Tensor* dim0 = c->input_tensor(3);
1201   ShapeHandle dim0_shape;
1202   if (dim0 == nullptr) {
1203     // We don't have the value at inference time, so the output
1204     // shape is unknown.
1205     dim0_shape = c->Vector(InferenceContext::kUnknownDim);
1206   } else {
1207     auto dim0_value = dim0->scalar<int32>()();
1208     if (dim0_value < 0) {
1209       return errors::InvalidArgument(
1210           "Cannot specify a negative value for output_dim0");
1211     }
1212     dim0_shape = c->Vector(dim0_value);
1213   }
1214 
1215   ShapeHandle out;
1216   TF_RETURN_IF_ERROR(c->Concatenate(dim0_shape, subshape, &out));
1217   c->set_output(0, out);
1218   return Status::OK();
1219 }
1220 
SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext * c)1221 Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) {
1222   ShapeHandle data_shape;
1223   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1224 
1225   ShapeHandle indices_shape;
1226   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
1227 
1228   ShapeHandle segment_ids_shape;
1229   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));
1230 
1231   ShapeHandle num_segments_shape;
1232   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &num_segments_shape));
1233 
1234   // indices and segment_ids should merge cleanly.
1235   ShapeHandle unused;
1236   TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));
1237 
1238   ShapeHandle subshape;
1239   TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1240 
1241   ShapeHandle out;
1242   const Tensor* dim0 = c->input_tensor(3);
1243   if (dim0 == nullptr) {
1244     // We don't have the value at inference time, so the output
1245     // shape is unknown.
1246     TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(InferenceContext::kUnknownDim),
1247                                       subshape, &out));
1248   } else {
1249     auto dim0_value = dim0->scalar<int32>()();
1250     if (dim0_value < 0) {
1251       return errors::InvalidArgument(
1252           "Cannot specify a negative value for num_segments");
1253     }
1254     TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(dim0_value), subshape, &out));
1255   }
1256   c->set_output(0, out);
1257   return Status::OK();
1258 }
1259 }  // namespace
1260 
1261 REGISTER_OP("SegmentSum")
1262     .Input("data: T")
1263     .Input("segment_ids: Tindices")
1264     .Output("output: T")
1265     .Attr("T: numbertype")
1266     .Attr("Tindices: {int32,int64}")
1267     .SetShapeFn(SegmentReductionShapeFn);
1268 
1269 REGISTER_OP("SegmentMean")
1270     .Input("data: T")
1271     .Input("segment_ids: Tindices")
1272     .Output("output: T")
1273     .Attr("T: numbertype")
1274     .Attr("Tindices: {int32,int64}")
1275     .SetShapeFn(SegmentReductionShapeFn);
1276 
1277 REGISTER_OP("SegmentProd")
1278     .Input("data: T")
1279     .Input("segment_ids: Tindices")
1280     .Output("output: T")
1281     .Attr("T: numbertype")
1282     .Attr("Tindices: {int32,int64}")
1283     .SetShapeFn(SegmentReductionShapeFn);
1284 
1285 REGISTER_OP("SegmentMin")
1286     .Input("data: T")
1287     .Input("segment_ids: Tindices")
1288     .Output("output: T")
1289     .Attr("T: realnumbertype")
1290     .Attr("Tindices: {int32,int64}")
1291     .SetShapeFn(SegmentReductionShapeFn);
1292 
1293 REGISTER_OP("SegmentMax")
1294     .Input("data: T")
1295     .Input("segment_ids: Tindices")
1296     .Output("output: T")
1297     .Attr("T: realnumbertype")
1298     .Attr("Tindices: {int32,int64}")
1299     .SetShapeFn(SegmentReductionShapeFn);
1300 
1301 REGISTER_OP("UnsortedSegmentSum")
1302     .Input("data: T")
1303     .Input("segment_ids: Tindices")
1304     .Input("num_segments: Tnumsegments")
1305     .Output("output: T")
1306     .Attr("T: numbertype")
1307     .Attr("Tindices: {int32,int64}")
1308     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1309     .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
1310 
1311 REGISTER_OP("UnsortedSegmentMax")
1312     .Input("data: T")
1313     .Input("segment_ids: Tindices")
1314     .Input("num_segments: Tnumsegments")
1315     .Output("output: T")
1316     .Attr("T: realnumbertype")
1317     .Attr("Tindices: {int32,int64}")
1318     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1319     .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
1320 
1321 REGISTER_OP("UnsortedSegmentMin")
1322     .Input("data: T")
1323     .Input("segment_ids: Tindices")
1324     .Input("num_segments: Tnumsegments")
1325     .Output("output: T")
1326     .Attr("T: realnumbertype")
1327     .Attr("Tindices: {int32,int64}")
1328     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1329     .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
1330 
1331 REGISTER_OP("UnsortedSegmentProd")
1332     .Input("data: T")
1333     .Input("segment_ids: Tindices")
1334     .Input("num_segments: Tnumsegments")
1335     .Output("output: T")
1336     .Attr("T: numbertype")
1337     .Attr("Tindices: {int32,int64}")
1338     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1339     .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
1340 
1341 REGISTER_OP("SparseSegmentSum")
1342     .Input("data: T")
1343     .Input("indices: Tidx")
1344     .Input("segment_ids: Tsegmentids")
1345     .Output("output: T")
1346     .Attr("T: realnumbertype")
1347     .Attr("Tidx: {int32, int64} = DT_INT32")
1348     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1349     .SetShapeFn(SparseSegmentReductionShapeFn);
1350 
1351 REGISTER_OP("SparseSegmentSumWithNumSegments")
1352     .Input("data: T")
1353     .Input("indices: Tidx")
1354     .Input("segment_ids: Tsegmentids")
1355     .Input("num_segments: Tnumsegments")
1356     .Output("output: T")
1357     .Attr("T: realnumbertype")
1358     .Attr("Tidx: {int32, int64} = DT_INT32")
1359     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1360     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1361     .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
1362 
1363 REGISTER_OP("SparseSegmentSumGrad")
1364     .Input("grad: T")
1365     .Input("indices: Tidx")
1366     .Input("segment_ids: Tsegmentids")
1367     .Input("output_dim0: int32")
1368     .Output("output: T")
1369     .Attr("T: {bfloat16, half, float, double}")
1370     .Attr("Tidx: {int32, int64} = DT_INT32")
1371     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1372     .SetShapeFn(SparseSegmentReductionGradShapeFn);
1373 
1374 REGISTER_OP("SparseSegmentMean")
1375     .Input("data: T")
1376     .Input("indices: Tidx")
1377     .Input("segment_ids: Tsegmentids")
1378     .Output("output: T")
1379     .Attr("T: {bfloat16, half, float, double}")
1380     .Attr("Tidx: {int32, int64} = DT_INT32")
1381     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1382     .SetShapeFn(SparseSegmentReductionShapeFn);
1383 
1384 REGISTER_OP("SparseSegmentMeanWithNumSegments")
1385     .Input("data: T")
1386     .Input("indices: Tidx")
1387     .Input("segment_ids: Tsegmentids")
1388     .Input("num_segments: Tnumsegments")
1389     .Output("output: T")
1390     .Attr("T: {bfloat16, half, float, double}")
1391     .Attr("Tidx: {int32, int64} = DT_INT32")
1392     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1393     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1394     .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
1395 
1396 REGISTER_OP("SparseSegmentMeanGrad")
1397     .Input("grad: T")
1398     .Input("indices: Tidx")
1399     .Input("segment_ids: Tsegmentids")
1400     .Input("output_dim0: int32")
1401     .Output("output: T")
1402     .Attr("T: {bfloat16, half, float, double}")
1403     .Attr("Tidx: {int32, int64} = DT_INT32")
1404     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1405     .SetShapeFn(SparseSegmentReductionGradShapeFn);
1406 
1407 REGISTER_OP("SparseSegmentSqrtN")
1408     .Input("data: T")
1409     .Input("indices: Tidx")
1410     .Input("segment_ids: Tsegmentids")
1411     .Output("output: T")
1412     .Attr("T: {bfloat16, half, float, double}")
1413     .Attr("Tidx: {int32, int64} = DT_INT32")
1414     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1415     .SetShapeFn(SparseSegmentReductionShapeFn);
1416 
1417 REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
1418     .Input("data: T")
1419     .Input("indices: Tidx")
1420     .Input("segment_ids: Tsegmentids")
1421     .Input("num_segments: Tnumsegments")
1422     .Output("output: T")
1423     .Attr("T: {bfloat16, half, float, double}")
1424     .Attr("Tidx: {int32, int64} = DT_INT32")
1425     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
1426     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1427     .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);
1428 
1429 REGISTER_OP("SparseSegmentSqrtNGrad")
1430     .Input("grad: T")
1431     .Input("indices: Tidx")
1432     .Input("segment_ids: Tsegmentids")
1433     .Input("output_dim0: int32")
1434     .Output("output: T")
1435     .Attr("T: {bfloat16, half, float, double}")
1436     .Attr("Tidx: {int32, int64} = DT_INT32")
1437     .Attr("Tsegmentids: {int32, int64} = DT_INT32")
1438     .SetShapeFn(SparseSegmentReductionGradShapeFn);
1439 
1440 REGISTER_OP("All")
1441     .Input("input: bool")
1442     .Input("reduction_indices: Tidx")
1443     .Output("output: bool")
1444     .Attr("keep_dims: bool = false")
1445     .Attr("Tidx: {int32, int64} = DT_INT32")
1446     .SetShapeFn(shape_inference::ReductionShape);
1447 
1448 REGISTER_OP("Any")
1449     .Input("input: bool")
1450     .Input("reduction_indices: Tidx")
1451     .Attr("keep_dims: bool = false")
1452     .Output("output: bool")
1453     .Attr("Tidx: {int32, int64} = DT_INT32")
1454     .SetShapeFn(shape_inference::ReductionShape);
1455 
1456 // --------------------------------------------------------------------------
1457 
1458 namespace {
1459 
1460 template <typename T>
RangeSize(const Tensor * start_t,const Tensor * limit_t,const Tensor * delta_t,InferenceContext * const c)1461 Status RangeSize(const Tensor* start_t, const Tensor* limit_t,
1462                  const Tensor* delta_t, InferenceContext* const c) {
1463   T start = start_t->scalar<T>()();
1464   T limit = limit_t->scalar<T>()();
1465   T delta = delta_t->scalar<T>()();
1466   if (start > limit && delta > T(0)) {
1467     return errors::InvalidArgument(
1468         "Requires start <= limit when delta > 0: ", start, "/", limit);
1469   }
1470   if (start < limit && delta < T(0)) {
1471     return errors::InvalidArgument(
1472         "Requires start >= limit when delta < 0: ", start, "/", limit);
1473   }
1474   if (delta == T(0)) {
1475     return errors::InvalidArgument("Requires delta != 0");
1476   }
1477 
1478   auto size = (std::is_integral<T>::value
1479                    ? ((Eigen::numext::abs(limit - start) +
1480                        Eigen::numext::abs(delta) - T(1)) /
1481                       Eigen::numext::abs(delta))
1482                    : (Eigen::numext::ceil(
1483                          Eigen::numext::abs((limit - start) / delta))));
1484   c->set_output(0, c->Vector(static_cast<int64>(size)));
1485   return Status::OK();
1486 }
1487 
1488 }  // namespace
1489 
1490 REGISTER_OP("Range")
1491     .Input("start: Tidx")
1492     .Input("limit: Tidx")
1493     .Input("delta: Tidx")
1494     .Output("output: Tidx")
1495     .Attr(
1496         "Tidx: "
1497         "{bfloat16, half, float, double, int8, int16, int32, int64, uint32} = "
1498         "DT_INT32")
__anon99f049890902(InferenceContext* c) 1499     .SetShapeFn([](InferenceContext* c) {
1500       ShapeHandle unused;
1501       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
1502                                       " for 'start'");
1503       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
1504                                       " for 'limit'");
1505       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
1506                                       " for 'delta'");
1507       const Tensor* start_t = c->input_tensor(0);
1508       const Tensor* limit_t = c->input_tensor(1);
1509       const Tensor* delta_t = c->input_tensor(2);
1510       DataType dtype;
1511       TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
1512       if (start_t == nullptr || limit_t == nullptr || delta_t == nullptr) {
1513         c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1514         return Status::OK();
1515       }
1516       if (dtype == DT_INT32) {
1517         return RangeSize<int32>(start_t, limit_t, delta_t, c);
1518       } else if (dtype == DT_INT16) {
1519         return RangeSize<int16>(start_t, limit_t, delta_t, c);
1520       } else if (dtype == DT_INT8) {
1521         return RangeSize<int8>(start_t, limit_t, delta_t, c);
1522       } else if (dtype == DT_INT64) {
1523         return RangeSize<int64>(start_t, limit_t, delta_t, c);
1524       } else if (dtype == DT_UINT32) {
1525         return RangeSize<uint32>(start_t, limit_t, delta_t, c);
1526       } else if (dtype == DT_FLOAT) {
1527         return RangeSize<float>(start_t, limit_t, delta_t, c);
1528       } else if (dtype == DT_DOUBLE) {
1529         return RangeSize<double>(start_t, limit_t, delta_t, c);
1530       } else if (dtype == DT_BFLOAT16) {
1531         return RangeSize<bfloat16>(start_t, limit_t, delta_t, c);
1532       } else {
1533         return errors::InvalidArgument("Unsupported dtype", dtype);
1534       }
1535       return Status::OK();
1536     });
1537 
1538 REGISTER_OP("LinSpace")
1539     .Input("start: T")
1540     .Input("stop: T")
1541     .Input("num: Tidx")
1542     .Output("output: T")
1543     .Attr("T: {bfloat16, half, float, double}")
1544     .Attr("Tidx: {int32, int64} = DT_INT32")
__anon99f049890a02(InferenceContext* c) 1545     .SetShapeFn([](InferenceContext* c) {
1546       ShapeHandle unused;
1547       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
1548                                       " for 'start'");
1549       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
1550                                       " for 'stop'");
1551       TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
1552                                       " for 'num'");
1553       const Tensor* num_t = c->input_tensor(2);
1554       if (num_t == nullptr) {
1555         c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1556         return Status::OK();
1557       }
1558 
1559       int64_t num;
1560       if (num_t->dtype() == DT_INT32) {
1561         num = num_t->scalar<int32>()();
1562       } else {
1563         num = num_t->scalar<int64>()();
1564       }
1565       if (num <= 0) return errors::InvalidArgument("Requires num > 0: ", num);
1566       c->set_output(0, c->Vector(num));
1567       return Status::OK();
1568     });
1569 
1570 REGISTER_OP("Complex")
1571     .Input("real: T")
1572     .Input("imag: T")
1573     .Output("out: Tout")
1574     .Attr("T: {float, double} = DT_FLOAT")
1575     .Attr("Tout: {complex64, complex128} = DT_COMPLEX64")
1576     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
1577 
1578 REGISTER_OP("Real")
1579     .Input("input: T")
1580     .Output("output: Tout")
1581     .Attr("T: {complex64, complex128} = DT_COMPLEX64")
1582     .Attr("Tout: {float, double} = DT_FLOAT")
1583     .SetShapeFn(shape_inference::UnchangedShape);
1584 
1585 REGISTER_OP("Imag")
1586     .Input("input: T")
1587     .Output("output: Tout")
1588     .Attr("T: {complex64, complex128} = DT_COMPLEX64")
1589     .Attr("Tout: {float, double} = DT_FLOAT")
1590     .SetShapeFn(shape_inference::UnchangedShape);
1591 
1592 REGISTER_OP("Angle")
1593     .Input("input: T")
1594     .Output("output: Tout")
1595     .Attr("T: {complex64, complex128} = DT_COMPLEX64")
1596     .Attr("Tout: {float, double} = DT_FLOAT")
1597     .SetShapeFn(shape_inference::UnchangedShape);
1598 
1599 REGISTER_OP("Conj")
1600     .Input("input: T")
1601     .Output("output: T")
1602     .Attr("T: {complex64, complex128, variant} = DT_COMPLEX64")
__anon99f049890b02(InferenceContext* c) 1603     .SetShapeFn([](InferenceContext* c) {
1604       c->set_output(0, c->input(0));
1605       auto* handle_data = c->input_handle_shapes_and_types(0);
1606       if (handle_data != nullptr) {
1607         c->set_output_handle_shapes_and_types(0, *handle_data);
1608       }
1609       return Status::OK();
1610     });
1611 
1612 // --------------------------------------------------------------------------
1613 
1614 REGISTER_OP("Cross")
1615     .Input("a: T")
1616     .Input("b: T")
1617     .Output("product: T")
1618     .Attr("T: realnumbertype")
__anon99f049890c02(InferenceContext* c) 1619     .SetShapeFn([](InferenceContext* c) {
1620       ShapeHandle a_shape;
1621       ShapeHandle b_shape;
1622       // * Input rank >= 1.
1623       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &a_shape));
1624       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &b_shape));
1625 
1626       // * Both inputs have the same shape.
1627       TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &a_shape));
1628 
1629       // * input_shape[-1] == 3.
1630       if (c->RankKnown(a_shape)) {
1631         int rank = c->Rank(a_shape);
1632         auto dim = c->Dim(a_shape, rank - 1);
1633         TF_RETURN_IF_ERROR(c->WithValue(dim, 3, &dim));
1634       }
1635       c->set_output(0, a_shape);
1636       return Status::OK();
1637     });
1638 
1639 // --------------------------------------------------------------------------
1640 
1641 REGISTER_OP("HistogramFixedWidth")
1642     .Input("values: T")
1643     .Input("value_range: T")
1644     .Input("nbins: int32")
1645     .Output("out: dtype")
1646     .Attr("T: {int32, int64, float32, float64}")
1647     .Attr("dtype: {int32, int64} = DT_INT32")
__anon99f049890d02(InferenceContext* c) 1648     .SetShapeFn([](InferenceContext* c) {
1649       // value_range should be a vector.
1650       ShapeHandle value_range_shape;
1651       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &value_range_shape));
1652       // value_range should have two elements.
1653       DimensionHandle unused;
1654       TF_RETURN_IF_ERROR(
1655           c->WithValue(c->Dim(value_range_shape, 0), 2, &unused));
1656       // nbins should be a scalar.
1657       ShapeHandle nbins_shape;
1658       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &nbins_shape));
1659 
1660       // If nbins is available, set the shape from nbins.
1661       const Tensor* nbins_input = c->input_tensor(2);
1662       if (nbins_input != nullptr) {
1663         int64_t nbins;
1664         TF_RETURN_IF_ERROR(c->GetScalarFromTensor(nbins_input, &nbins));
1665         // nbins has to be positive.
1666         if (nbins <= 0) {
1667           return errors::InvalidArgument("Requires nbins > 0: ", nbins);
1668         }
1669         c->set_output(0, c->Vector(nbins));
1670       } else {
1671         c->set_output(0, c->UnknownShapeOfRank(1));
1672       }
1673       return Status::OK();
1674     });
1675 
1676 REGISTER_OP("Bincount")
1677     .Input("arr: int32")
1678     .Input("size: int32")
1679     .Input("weights: T")
1680     .Attr("T: {int32, int64, float32, float64}")
1681     .Output("bins: T")
__anon99f049890e02(InferenceContext* c) 1682     .SetShapeFn([](InferenceContext* c) {
1683       ShapeHandle unused;
1684       // The input `size` must be a scalar.
1685       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1686 
1687       const Tensor* size_tensor = c->input_tensor(1);
1688       if (size_tensor == nullptr) {
1689         // Return unknown shape if size is not known.
1690         c->set_output(0, c->UnknownShapeOfRank(1));
1691         return Status::OK();
1692       }
1693 
1694       // Return `[size]` shape if size is known.
1695       int32_t size_val = size_tensor->scalar<int32>()();
1696       if (size_val < 0) {
1697         return errors::InvalidArgument("size (", size_val,
1698                                        ") must be non-negative");
1699       }
1700       c->set_output(0, c->MakeShape({size_val}));
1701       return Status::OK();
1702     });
1703 
1704 REGISTER_OP("DenseBincount")
1705     .Input("input: Tidx")
1706     .Input("size: Tidx")
1707     .Input("weights: T")
1708     .Attr("Tidx: {int32, int64}")
1709     .Attr("T: {int32, int64, float32, float64}")
1710     .Attr("binary_output: bool = false")
1711     .Output("output: T")
__anon99f049890f02(InferenceContext* c) 1712     .SetShapeFn([](InferenceContext* c) {
1713       ShapeHandle unused;
1714       // The input `input` must be at most matrix.
1715       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 2, &unused));
1716       // The input `size` must be a scalar.
1717       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1718 
1719       const Tensor* size_tensor = c->input_tensor(1);
1720       if (size_tensor == nullptr) {
1721         // Return unknown shape if size is not known.
1722         c->set_output(0, c->UnknownShape());
1723         return Status::OK();
1724       }
1725 
1726       int64_t size_val;
1727       DataType dtype;
1728       TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
1729       if (dtype == DT_INT32) {
1730         size_val = static_cast<int64>(size_tensor->scalar<int32>()());
1731       } else if (dtype == DT_INT64) {
1732         size_val = size_tensor->scalar<int64>()();
1733       } else {
1734         return errors::InvalidArgument("size dtype must be int32 or int64");
1735       }
1736       // Return `[size]` shape if size is known.
1737       if (size_val < 0) {
1738         return errors::InvalidArgument("size (", size_val,
1739                                        ") must be non-negative");
1740       }
1741       if (c->Rank(c->input(0)) == 1) {
1742         c->set_output(0, c->MakeShape({size_val}));
1743       } else if (c->Rank(c->input(0)) == 2) {
1744         c->set_output(0, c->MakeShape({c->Dim(c->input(0), 0), size_val}));
1745       }
1746       return Status::OK();
1747     });
1748 
1749 REGISTER_OP("SparseBincount")
1750     .Input("indices: int64")
1751     .Input("values: Tidx")
1752     .Input("dense_shape: int64")
1753     .Input("size: Tidx")
1754     .Input("weights: T")
1755     .Attr("Tidx: {int32, int64}")
1756     .Attr("T: {int32, int64, float32, float64}")
1757     .Attr("binary_output: bool = false")
1758     .Output("output: T")
__anon99f049891002(InferenceContext* c) 1759     .SetShapeFn([](InferenceContext* c) {
1760       const Tensor* size_tensor = c->input_tensor(3);
1761       if (size_tensor == nullptr) {
1762         // Return unknown shape if size is not known.
1763         c->set_output(0, c->UnknownShape());
1764         return Status::OK();
1765       }
1766 
1767       int64_t size_val;
1768       DataType dtype;
1769       TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
1770       if (dtype == DT_INT32) {
1771         size_val = static_cast<int64>(size_tensor->scalar<int32>()());
1772       } else if (dtype == DT_INT64) {
1773         size_val = size_tensor->scalar<int64>()();
1774       } else {
1775         return errors::InvalidArgument("size dtype must be int32 or int64");
1776       }
1777       // Return `[size]` shape if size is known.
1778       if (size_val < 0) {
1779         return errors::InvalidArgument("size (", size_val,
1780                                        ") must be non-negative");
1781       }
1782 
1783       const Tensor* shape_tensor = c->input_tensor(2);
1784       if (shape_tensor == nullptr) {
1785         // Return unknown shape if size is not known.
1786         c->set_output(0, c->UnknownShape());
1787         return Status::OK();
1788       }
1789       if (shape_tensor->NumElements() == 1) {
1790         c->set_output(0, c->MakeShape({size_val}));
1791       } else if (shape_tensor->NumElements() == 2) {
1792         c->set_output(0,
1793                       c->MakeShape({shape_tensor->flat<int64>()(0), size_val}));
1794       } else {
1795         return errors::InvalidArgument("Input must be less than rank 2");
1796       }
1797       return Status::OK();
1798     });
1799 
1800 REGISTER_OP("RaggedBincount")
1801     .Input("splits: int64")
1802     .Input("values: Tidx")
1803     .Input("size: Tidx")
1804     .Input("weights: T")
1805     .Attr("Tidx: {int32, int64}")
1806     .Attr("T: {int32, int64, float32, float64}")
1807     .Attr("binary_output: bool = false")
1808     .Output("output: T")
__anon99f049891102(InferenceContext* c) 1809     .SetShapeFn([](InferenceContext* c) {
1810       c->set_output(0, c->UnknownShape());
1811       return Status::OK();
1812     });
1813 
1814 REGISTER_OP("Cumsum")
1815     .Input("x: T")
1816     .Input("axis: Tidx")
1817     .Attr("exclusive: bool = false")
1818     .Attr("reverse: bool = false")
1819     .Output("out: T")
1820     .Attr("T: numbertype")
1821     .Attr("Tidx: {int32, int64} = DT_INT32")
1822     .SetShapeFn(shape_inference::UnchangedShape);
1823 
1824 REGISTER_OP("Cumprod")
1825     .Input("x: T")
1826     .Input("axis: Tidx")
1827     .Attr("exclusive: bool = false")
1828     .Attr("reverse: bool = false")
1829     .Output("out: T")
1830     .Attr("T: numbertype")
1831     .Attr("Tidx: {int32, int64} = DT_INT32")
1832     .SetShapeFn(shape_inference::UnchangedShape);
1833 
1834 REGISTER_OP("CumulativeLogsumexp")
1835     .Input("x : T")
1836     .Input("axis: Tidx")
1837     .Attr("exclusive: bool = false")
1838     .Attr("reverse: bool = false")
1839     .Output("out: T")
1840     .Attr("T: {float16, float32, float64}")
1841     .Attr("Tidx: {int32, int64} = DT_INT32")
1842     .SetShapeFn(shape_inference::UnchangedShape);
1843 
1844 REGISTER_OP("QuantizedMatMul")
1845     .Input("a: T1")
1846     .Input("b: T2")
1847     .Input("min_a: float")
1848     .Input("max_a: float")
1849     .Input("min_b: float")
1850     .Input("max_b: float")
1851     .Output("out: Toutput")
1852     .Output("min_out: float")
1853     .Output("max_out: float")
1854     .Attr("T1: quantizedtype")
1855     .Attr("T2: quantizedtype")
1856     .Attr("Toutput: quantizedtype = DT_QINT32")
1857     .Attr("transpose_a: bool = false")
1858     .Attr("transpose_b: bool = false")
1859     .Attr("Tactivation: quantizedtype = DT_QUINT8")
__anon99f049891202(InferenceContext* c) 1860     .SetShapeFn([](InferenceContext* c) {
1861       TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
1862       ShapeHandle unused;
1863       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1864       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1865       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1866       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
1867 
1868       c->set_output(1, c->Scalar());
1869       c->set_output(2, c->Scalar());
1870       return Status::OK();
1871     });
1872 
1873 // Note: This op is not commutative w.r.t. to all its inputs.
1874 REGISTER_OP("QuantizedMul")
1875     .Input("x: T1")
1876     .Input("y: T2")
1877     .Input("min_x: float")
1878     .Input("max_x: float")
1879     .Input("min_y: float")
1880     .Input("max_y: float")
1881     .Output("z: Toutput")
1882     .Output("min_z: float")
1883     .Output("max_z: float")
1884     .Attr("T1: quantizedtype")
1885     .Attr("T2: quantizedtype")
1886     .Attr("Toutput: quantizedtype = DT_QINT32")
__anon99f049891302(InferenceContext* c) 1887     .SetShapeFn([](InferenceContext* c) {
1888       TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
1889       c->set_output(1, c->Scalar());
1890       c->set_output(2, c->Scalar());
1891       return Status::OK();
1892     });
1893 
1894 // Note: This op is not commutative w.r.t. to all its inputs.
1895 REGISTER_OP("QuantizedAdd")
1896     .Input("x: T1")
1897     .Input("y: T2")
1898     .Input("min_x: float")
1899     .Input("max_x: float")
1900     .Input("min_y: float")
1901     .Input("max_y: float")
1902     .Output("z: Toutput")
1903     .Output("min_z: float")
1904     .Output("max_z: float")
1905     .Attr("T1: quantizedtype")
1906     .Attr("T2: quantizedtype")
1907     .Attr("Toutput: quantizedtype = DT_QINT32")
__anon99f049891402(InferenceContext* c) 1908     .SetShapeFn([](InferenceContext* c) {
1909       TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
1910       // min_x, max_x, min_y, max_y should be scalar.
1911       ShapeHandle unused;
1912       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1913       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1914       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1915       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
1916 
1917       c->set_output(1, c->Scalar());
1918       c->set_output(2, c->Scalar());
1919       return Status::OK();
1920     });
1921 
1922 REGISTER_OP("QuantizeDownAndShrinkRange")
1923     .Input("input: Tinput")
1924     .Input("input_min: float")
1925     .Input("input_max: float")
1926     .Output("output: out_type")
1927     .Output("output_min: float")
1928     .Output("output_max: float")
1929     .Attr("Tinput: quantizedtype")
1930     .Attr("out_type: quantizedtype")
__anon99f049891502(InferenceContext* c) 1931     .SetShapeFn([](InferenceContext* c) {
1932       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1933       ShapeHandle unused;
1934       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1935       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1936       c->set_output(1, c->Scalar());
1937       c->set_output(2, c->Scalar());
1938       return Status::OK();
1939     });
1940 
1941 REGISTER_OP("Requantize")
1942     .Input("input: Tinput")
1943     .Input("input_min: float")
1944     .Input("input_max: float")
1945     .Input("requested_output_min: float")
1946     .Input("requested_output_max: float")
1947     .Output("output: out_type")
1948     .Output("output_min: float")
1949     .Output("output_max: float")
1950     .Attr("Tinput: quantizedtype")
1951     .Attr("out_type: quantizedtype")
__anon99f049891602(InferenceContext* c) 1952     .SetShapeFn([](InferenceContext* c) {
1953       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1954       ShapeHandle unused;
1955       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1956       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1957       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1958       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1959       c->set_output(1, c->Scalar());
1960       c->set_output(2, c->Scalar());
1961       return Status::OK();
1962     });
1963 
1964 REGISTER_OP("RequantizationRange")
1965     .Input("input: Tinput")
1966     .Input("input_min: float")
1967     .Input("input_max: float")
1968     .Output("output_min: float")
1969     .Output("output_max: float")
1970     .Attr("Tinput: quantizedtype")
__anon99f049891702(InferenceContext* c) 1971     .SetShapeFn([](InferenceContext* c) {
1972       ShapeHandle unused;
1973       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1974       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1975       c->set_output(0, c->Scalar());
1976       c->set_output(1, c->Scalar());
1977       return Status::OK();
1978     });
1979 
1980 // --------------------------------------------------------------------------
1981 
1982 REGISTER_OP("Bucketize")
1983     .Input("input: T")
1984     .Output("output: int32")
1985     .Attr("T: {int32, int64, float, double}")
1986     .Attr("boundaries: list(float)")
1987     .SetShapeFn(shape_inference::UnchangedShape);
1988 
1989 REGISTER_OP("ClipByValue")
1990     .Input("t: T")
1991     .Input("clip_value_min: T")
1992     .Input("clip_value_max: T")
1993     .Output("output: T")
1994     .Attr("T: numbertype")
1995     .SetShapeFn(shape_inference::UnchangedShape);
1996 
1997 #ifdef INTEL_MKL
1998 // Note: This op is not commutative w.r.t. to all its inputs.
1999 REGISTER_OP("_MklAddN")
2000     .Input("inputs: N * T")
2001     .Input("mkl_input: N * uint8")
2002     .Output("sum: T")
2003     .Output("mkl_sum: uint8")
2004     .Attr("N: int >= 1")
2005     .Attr("T: numbertype")
__anon99f049891802(InferenceContext* c) 2006     .SetShapeFn([](InferenceContext* c) {
2007       ShapeHandle cur = c->input(c->num_inputs() - 1);
2008       for (int i = c->num_inputs() - 2; i >= 0; --i) {
2009         TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
2010                                         "From merging shape ", i,
2011                                         " with other shapes.");
2012       }
2013       c->set_output(0, cur);
2014       return Status::OK();
2015     })
2016     .Doc(R"doc(
2017 Add two input tensors element wise using mkl kernel sum.
2018 inputs: Must all be the same size and shape.
2019 )doc");
2020 
2021 #endif  // INTEL_MKL
2022 
2023 REGISTER_OP("RequantizePerChannel")
2024     .Input("input: T")
2025     .Input("input_min: float")
2026     .Input("input_max: float")
2027     .Input("requested_output_min: float")
2028     .Input("requested_output_max: float")
2029     .Output("output: out_type")
2030     .Output("output_min: float")
2031     .Output("output_max: float")
2032     .Attr("T: quantizedtype = DT_QINT32")
2033     .Attr("out_type: quantizedtype = DT_QUINT8")
__anon99f049891902(InferenceContext* c) 2034     .SetShapeFn([](InferenceContext* c) {
2035       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
2036       ShapeHandle unused;
2037       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
2038       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2039       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2040       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2041       c->set_output(1, c->Scalar());
2042       c->set_output(2, c->Scalar());
2043       return Status::OK();
2044     });
2045 REGISTER_OP("RequantizationRangePerChannel")
2046     .Input("input: T")
2047     .Input("input_min: float")
2048     .Input("input_max: float")
2049     .Output("output_min: float")
2050     .Output("output_max: float")
2051     .Attr("T: quantizedtype = DT_QINT32")
2052     .Attr("clip_value_max: float")
__anon99f049891a02(InferenceContext* c) 2053     .SetShapeFn([](InferenceContext* c) {
2054       ShapeHandle unused;
2055       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
2056       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2057       c->set_output(0, c->Scalar());
2058       c->set_output(1, c->Scalar());
2059       return Status::OK();
2060     });
2061 
2062 REGISTER_OP("NextAfter")
2063     .Attr("T: {float64, float32} = DT_FLOAT")
2064     .Input("x1: T")
2065     .Input("x2: T")
2066     .Output("output: T")
2067     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
2068 
2069 REGISTER_OP("SobolSample")
2070     .Input("dim: int32")
2071     .Input("num_results: int32")
2072     .Input("skip: int32")
2073     .Attr("dtype: {float, double} = DT_FLOAT")
2074     .Output("samples: dtype")
__anon99f049891b02(shape_inference::InferenceContext* c) 2075     .SetShapeFn([](shape_inference::InferenceContext* c) {
2076       ShapeHandle unused;
2077 
2078       // inputs must be scalars
2079       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
2080       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2081       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2082 
2083       const Tensor* dim_t = c->input_tensor(0);
2084       const Tensor* num_results_t = c->input_tensor(1);
2085 
2086       int32_t dim = dim_t == nullptr ? InferenceContext::kUnknownDim
2087                                      : dim_t->scalar<int32>()();
2088 
2089       int32_t num_results = num_results_t == nullptr
2090                                 ? InferenceContext::kUnknownDim
2091                                 : num_results_t->scalar<int32>()();
2092 
2093       c->set_output(0, c->Matrix(num_results, dim));
2094       return Status::OK();
2095     });
2096 
2097 }  // namespace tensorflow
2098